diff --git a/cpp/bench/sg/fil.cu b/cpp/bench/sg/fil.cu
index fb842b77bf..f915823245 100644
--- a/cpp/bench/sg/fil.cu
+++ b/cpp/bench/sg/fil.cu
@@ -150,19 +150,15 @@ std::vector<Params> getInputs() {
     (1 << 20),          /* max_leaves */
     1.f,                /* max_features */
     32,                 /* n_bins */
-    1,                  /* split_algo */
     3,                  /* min_samples_leaf */
     3,                  /* min_samples_split */
     0.0f,               /* min_impurity_decrease */
-    true,               /* bootstrap_features */
     true,               /* bootstrap */
     1,                  /* n_trees */
     1.f,                /* max_samples */
     1234ULL,            /* seed */
     ML::CRITERION::MSE, /* split_criterion */
-    false,              /* quantile_per_tree */
     8,                  /* n_streams */
-    false,              /* use_experimental_backend */
     128                 /* max_batch_size */
   );
 
diff --git a/cpp/bench/sg/rf_classifier.cu b/cpp/bench/sg/rf_classifier.cu
index 021e88acb7..9f06450675 100644
--- a/cpp/bench/sg/rf_classifier.cu
+++ b/cpp/bench/sg/rf_classifier.cu
@@ -86,19 +86,15 @@ std::vector<Params> getInputs() {
     (1 << 20),           /* max_leaves */
     0.3,                 /* max_features */
     32,                  /* n_bins */
-    1,                   /* split_algo */
     3,                   /* min_samples_leaf */
     3,                   /* min_samples_split */
     0.0f,                /* min_impurity_decrease */
-    true,                /* bootstrap_features */
     true,                /* bootstrap */
     500,                 /* n_trees */
     1.f,                 /* max_samples */
     1234ULL,             /* seed */
     ML::CRITERION::GINI, /* split_criterion */
-    false,               /* quantile_per_tree */
     8,                   /* n_streams */
-    false,               /* use_experimental_backend */
     128                  /* max_batch_size */
   );
 
diff --git a/cpp/bench/sg/rf_regressor.cu b/cpp/bench/sg/rf_regressor.cu
index e50ea1e2f7..24f08f0635 100644
--- a/cpp/bench/sg/rf_regressor.cu
+++ b/cpp/bench/sg/rf_regressor.cu
@@ -88,19 +88,15 @@ std::vector<Params> getInputs() {
     (1 << 20),          /* max_leaves */
     0.3,                /* max_features */
     32,                 /* n_bins */
-    1,                  /* split_algo */
     3,                  /* min_samples_leaf */
     3,                  /* min_samples_split */
     0.0f,               /* min_impurity_decrease */
-    true,               /* bootstrap_features */
     true,               /* bootstrap */
     500,                /* n_trees */
     1.f,                /* max_samples */
     1234ULL,            /* seed */
     ML::CRITERION::MSE, /* split_criterion */
-    false,              /* quantile_per_tree */
     8,                  /* n_streams */
-    false,              /* use_experimental_backend */
     128                 /* max_batch_size */
   );
   std::vector<DimInfo> dim_info = {{500000, 500, 400}};
diff --git a/cpp/include/cuml/ensemble/randomforest.hpp b/cpp/include/cuml/ensemble/randomforest.hpp
index 9321b93307..a663d16e0e 100644
--- a/cpp/include/cuml/ensemble/randomforest.hpp
+++ b/cpp/include/cuml/ensemble/randomforest.hpp
@@ -181,12 +181,10 @@ RF_metrics score(const raft::handle_t& user_handle,
                  int verbosity = CUML_LEVEL_INFO);
 
 RF_params set_rf_params(int max_depth, int max_leaves, float max_features,
-                        int n_bins, int split_algo, int min_samples_leaf,
-                        int min_samples_split, float min_impurity_decrease,
-                        bool bootstrap_features, bool bootstrap, int n_trees,
-                        float max_samples, uint64_t seed,
-                        CRITERION split_criterion, bool quantile_per_tree,
-                        int cfg_n_streams, bool use_experimental_backend,
+                        int n_bins, int min_samples_leaf, int min_samples_split,
+                        float min_impurity_decrease, bool bootstrap,
+                        int n_trees, float max_samples, uint64_t seed,
+                        CRITERION split_criterion, int cfg_n_streams,
                         int max_batch_size);
 
 // ----------------------------- Regression ----------------------------------- //
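Migration for callers of set_rf_params is mechanical: the four removed arguments (split_algo, bootstrap_features, quantile_per_tree, use_experimental_backend) are simply dropped, and every remaining argument keeps its relative order. A minimal sketch of a post-change call, reusing the benchmark values from the hunks above (max_depth = 10 is illustrative, not a default):

    RF_params rf_params =
      ML::set_rf_params(10,                  /* max_depth */
                        (1 << 20),           /* max_leaves */
                        0.3,                 /* max_features */
                        32,                  /* n_bins */
                        3,                   /* min_samples_leaf */
                        3,                   /* min_samples_split */
                        0.0f,                /* min_impurity_decrease */
                        true,                /* bootstrap */
                        500,                 /* n_trees */
                        1.f,                 /* max_samples */
                        1234ULL,             /* seed */
                        ML::CRITERION::GINI, /* split_criterion */
                        8,                   /* n_streams */
                        128                  /* max_batch_size */);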
diff --git a/cpp/include/cuml/tree/algo_helper.h b/cpp/include/cuml/tree/algo_helper.h
index 28d00ce010..28b4ac0e5d 100644
--- a/cpp/include/cuml/tree/algo_helper.h
+++ b/cpp/include/cuml/tree/algo_helper.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2021, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,12 +17,6 @@
 #pragma once
 
 namespace ML {
-enum SPLIT_ALGO {
-  HIST,
-  GLOBAL_QUANTILE,
-  SPLIT_ALGO_END,
-};
-
 enum CRITERION {
   GINI,
   ENTROPY,
diff --git a/cpp/include/cuml/tree/decisiontree.hpp b/cpp/include/cuml/tree/decisiontree.hpp
index a66a3a1c1b..c483783679 100644
--- a/cpp/include/cuml/tree/decisiontree.hpp
+++ b/cpp/include/cuml/tree/decisiontree.hpp
@@ -44,10 +44,6 @@ struct DecisionTreeParams {
    * Number of bins used by the split algorithm.
    */
   int n_bins;
-  /**
-   * The split algorithm: HIST or GLOBAL_QUANTILE.
-   */
-  int split_algo;
   /**
    * The minimum number of samples (rows) in each leaf node.
    */
@@ -57,16 +53,7 @@
    */
   int min_samples_split;
   /**
-   * Control bootstrapping for features. If features are drawn with or without replacement
-   */
-  bool bootstrap_features;
-  /**
-   * Whether a quantile needs to be computed for individual trees in RF.
-   * Default: compute quantiles once per RF. Only affects GLOBAL_QUANTILE split_algo.
-   */
-  bool quantile_per_tree;
-  /**
-   * Node split criterion. GINI and Entropy for classification, MSE or MAE for regression.
+   * Node split criterion. GINI and Entropy for classification, MSE for regression.
    */
   CRITERION split_criterion;
   /**
@@ -79,14 +66,6 @@
    * used only for batched-level algo
    */
   int max_batch_size;
-  /**
-   * If set to true and following conditions are also met, experimental decision
-   * tree training implementation would be used:
-   *   split_algo = 1 (GLOBAL_QUANTILE)
-   *   max_features = 1.0 (Feature sub-sampling disabled)
-   *   quantile_per_tree = false (No per tree quantile computation)
-   */
-  bool use_experimental_backend;
 };
 
 /**
@@ -96,33 +75,23 @@
  * @param[in] cfg_max_leaves: maximum leaves; default -1
  * @param[in] cfg_max_features: maximum number of features; default 1.0f
 * @param[in] cfg_n_bins: number of bins; default 128
- * @param[in] cfg_split_algo: split algorithm; default SPLIT_ALGO::HIST
 * @param[in] cfg_min_samples_leaf: min. rows in each leaf node; default 1
 * @param[in] cfg_min_samples_split: min. rows needed to split an internal node;
 *            default 2
 * @param[in] cfg_min_impurity_decrease: split a node only if its reduction in
 *            impurity is more than this value
- * @param[in] cfg_bootstrap_features: bootstrapping for features; default false
 * @param[in] cfg_split_criterion: split criterion; default CRITERION_END,
 *            i.e., GINI for classification or MSE for regression
- * @param[in] cfg_quantile_per_tree: compute quantile per tree; default false
- * @param[in] cfg_use_experimental_backend: When set to true, experimental batched
- *            backend is used (provided other conditions are met). Default is
- *            True.
 * @param[in] cfg_max_batch_size: Maximum number of nodes that can be processed
 *            in a batch. This is used only for batched-level algo. Default
 *            value 128.
 */
 void set_tree_params(DecisionTreeParams &params, int cfg_max_depth = -1,
                      int cfg_max_leaves = -1, float cfg_max_features = 1.0f,
-                     int cfg_n_bins = 8, int cfg_split_algo = SPLIT_ALGO::HIST,
-                     int cfg_min_samples_leaf = 1,
+                     int cfg_n_bins = 128, int cfg_min_samples_leaf = 1,
                      int cfg_min_samples_split = 2,
                      float cfg_min_impurity_decrease = 0.0f,
-                     bool cfg_bootstrap_features = false,
                      CRITERION cfg_split_criterion = CRITERION_END,
-                     bool cfg_quantile_per_tree = false,
-                     bool cfg_use_experimental_backend = true,
                      int cfg_max_batch_size = 128);
 
 /**
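The single-tree API shrinks the same way, and the n_bins default moves from 8 to 128. A minimal sketch of the surviving entry point with every remaining default spelled out (equivalent to calling set_tree_params(params) with no optional arguments; the ML::DecisionTree qualification assumes the header's namespaces):

    ML::DecisionTree::DecisionTreeParams params;
    ML::DecisionTree::set_tree_params(params,
                                      -1,                /* cfg_max_depth */
                                      -1,                /* cfg_max_leaves */
                                      1.0f,              /* cfg_max_features */
                                      128,               /* cfg_n_bins (was 8) */
                                      1,                 /* cfg_min_samples_leaf */
                                      2,                 /* cfg_min_samples_split */
                                      0.0f,              /* cfg_min_impurity_decrease */
                                      ML::CRITERION_END, /* cfg_split_criterion */
                                      128                /* cfg_max_batch_size */);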
diff --git a/cpp/src/decisiontree/decisiontree.cu b/cpp/src/decisiontree/decisiontree.cu
index 5eda766d58..349df605f9 100644
--- a/cpp/src/decisiontree/decisiontree.cu
+++ b/cpp/src/decisiontree/decisiontree.cu
@@ -14,7 +14,9 @@
  * limitations under the License.
  */
 
+#include
 #include
+#include
 #include
 
 #include "decisiontree_impl.cuh"
@@ -29,59 +31,25 @@ namespace DecisionTree {
 * @param[in] cfg_max_leaves: maximum leaves; default -1
 * @param[in] cfg_max_features: maximum number of features; default 1.0f
 * @param[in] cfg_n_bins: number of bins; default 128
- * @param[in] cfg_split_algo: split algorithm; default SPLIT_ALGO::HIST
 * @param[in] cfg_min_samples_leaf: min. rows in each leaf node; default 1
 * @param[in] cfg_min_samples_split: min. rows needed to split an internal node;
 *            default 2
- * @param[in] cfg_bootstrap_features: bootstrapping for features; default false
 * @param[in] cfg_split_criterion: split criterion; default CRITERION_END,
 *            i.e., GINI for classification or MSE for regression
- * @param[in] cfg_quantile_per_tree: compute quantile per tree; default false
- * @param[in] cfg_use_experimental_backend: Switch to using experimental
- *            backend; default false
 * @param[in] cfg_max_batch_size: batch size for the batched-level algorithm
 */
 void set_tree_params(DecisionTreeParams &params, int cfg_max_depth,
                      int cfg_max_leaves, float cfg_max_features, int cfg_n_bins,
-                     int cfg_split_algo, int cfg_min_samples_leaf,
-                     int cfg_min_samples_split, float cfg_min_impurity_decrease,
-                     bool cfg_bootstrap_features, CRITERION cfg_split_criterion,
-                     bool cfg_quantile_per_tree,
-                     bool cfg_use_experimental_backend,
-                     int cfg_max_batch_size) {
-  if (cfg_use_experimental_backend) {
-    if (cfg_split_algo != SPLIT_ALGO::GLOBAL_QUANTILE) {
-      CUML_LOG_WARN(
-        "Experimental backend does not yet support histogram split algorithm");
-      CUML_LOG_WARN(
-        "To use experimental backend set split_algo = 1 (GLOBAL_QUANTILE)");
-      cfg_use_experimental_backend = false;
-    }
-    if (cfg_quantile_per_tree) {
-      CUML_LOG_WARN(
-        "Experimental backend does not yet support per tree quantile "
-        "computation");
-      CUML_LOG_WARN(
-        "To use experimental backend set quantile_per_tree = false");
-      cfg_use_experimental_backend = false;
-    }
-    if (!cfg_use_experimental_backend) {
-      CUML_LOG_WARN(
-        "Not using the experimental backend due to above mentioned reason(s)");
-    }
-  }
-
+                     int cfg_min_samples_leaf, int cfg_min_samples_split,
+                     float cfg_min_impurity_decrease,
+                     CRITERION cfg_split_criterion, int cfg_max_batch_size) {
   params.max_depth = cfg_max_depth;
   params.max_leaves = cfg_max_leaves;
   params.max_features = cfg_max_features;
   params.n_bins = cfg_n_bins;
-  params.split_algo = cfg_split_algo;
   params.min_samples_leaf = cfg_min_samples_leaf;
   params.min_samples_split = cfg_min_samples_split;
-  params.bootstrap_features = cfg_bootstrap_features;
   params.split_criterion = cfg_split_criterion;
-  params.quantile_per_tree = cfg_quantile_per_tree;
-  params.use_experimental_backend = cfg_use_experimental_backend;
   params.min_impurity_decrease = cfg_min_impurity_decrease;
   params.max_batch_size = cfg_max_batch_size;
 }
@@ -95,10 +63,6 @@ void validity_check(const DecisionTreeParams params) {
          params.max_features);
   ASSERT((params.n_bins > 0), "Invalid n_bins %d", params.n_bins);
   ASSERT((params.split_criterion != 3), "MAE not supported.");
-  ASSERT((params.split_algo >= 0) &&
-           (params.split_algo < SPLIT_ALGO::SPLIT_ALGO_END),
-         "split_algo value %d outside permitted [0, %d) range",
-         params.split_algo, SPLIT_ALGO::SPLIT_ALGO_END);
   ASSERT((params.min_samples_leaf >= 1),
          "Invalid value for min_samples_leaf %d. Should be >= 1.",
          params.min_samples_leaf);
@@ -112,15 +76,10 @@ void print(const DecisionTreeParams params) {
   CUML_LOG_DEBUG("max_leaves: %d", params.max_leaves);
   CUML_LOG_DEBUG("max_features: %f", params.max_features);
   CUML_LOG_DEBUG("n_bins: %d", params.n_bins);
-  CUML_LOG_DEBUG("split_algo: %d", params.split_algo);
   CUML_LOG_DEBUG("min_samples_leaf: %d", params.min_samples_leaf);
   CUML_LOG_DEBUG("min_samples_split: %d", params.min_samples_split);
-  CUML_LOG_DEBUG("bootstrap_features: %d", params.bootstrap_features);
   CUML_LOG_DEBUG("split_criterion: %d", params.split_criterion);
-  CUML_LOG_DEBUG("quantile_per_tree: %d", params.quantile_per_tree);
   CUML_LOG_DEBUG("min_impurity_decrease: %f", params.min_impurity_decrease);
-  CUML_LOG_DEBUG("use_experimental_backend: %s",
-                 params.use_experimental_backend ? "True" : "False");
   CUML_LOG_DEBUG("max_batch_size: %d", params.max_batch_size);
 }
 
@@ -159,21 +118,15 @@ void decisionTreeClassifierFit(const raft::handle_t &handle,
                                uint64_t seed) {
   std::shared_ptr<DecisionTreeClassifier<float>> dt_classifier =
     std::make_shared<DecisionTreeClassifier<float>>();
-  std::unique_ptr<MLCommon::device_buffer<float>> global_quantiles_buffer =
-    nullptr;
-  float *global_quantiles = nullptr;
-
-  if (tree_params.use_experimental_backend) {
-    auto quantile_size = tree_params.n_bins * ncols;
-    global_quantiles_buffer = std::make_unique<MLCommon::device_buffer<float>>(
-      handle.get_device_allocator(), handle.get_stream(), quantile_size);
-    global_quantiles = global_quantiles_buffer->data();
-    DecisionTree::computeQuantiles(global_quantiles, tree_params.n_bins, data,
-                                   nrows, ncols, handle.get_device_allocator(),
-                                   handle.get_stream());
-  }
+  auto quantile_size = tree_params.n_bins * ncols;
+  MLCommon::device_buffer<float> global_quantiles_buffer(
+    handle.get_device_allocator(), handle.get_stream(), quantile_size);
+  DecisionTree::computeQuantiles(
+    global_quantiles_buffer.data(), tree_params.n_bins, data, nrows, ncols,
+    handle.get_device_allocator(), handle.get_stream());
   dt_classifier->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows,
-                     unique_labels, tree, tree_params, seed, global_quantiles);
+                     unique_labels, tree, tree_params, seed,
+                     global_quantiles_buffer.data());
 }
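All four fit entry points now run this setup unconditionally instead of gating it on use_experimental_backend. A hypothetical refactor sketch (not part of this diff) of the repeated pattern; it assumes device_buffer's resize(n, stream) member, used elsewhere in cuML:

    // Quantiles are always computed once, up front, on the full dataset:
    // one set of n_bins cut points per input column, shared by every tree.
    template <typename T>
    void compute_global_quantiles(const raft::handle_t &handle, const T *data,
                                  int nrows, int ncols, int n_bins,
                                  MLCommon::device_buffer<T> &quantiles) {
      quantiles.resize(static_cast<size_t>(n_bins) * ncols, handle.get_stream());
      ML::DecisionTree::computeQuantiles(quantiles.data(), n_bins, data, nrows,
                                         ncols, handle.get_device_allocator(),
                                         handle.get_stream());
    }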
@@ -185,21 +138,16 @@ void decisionTreeClassifierFit(const raft::handle_t &handle,
                                uint64_t seed) {
   std::shared_ptr<DecisionTreeClassifier<double>> dt_classifier =
     std::make_shared<DecisionTreeClassifier<double>>();
-  std::unique_ptr<MLCommon::device_buffer<double>> global_quantiles_buffer =
-    nullptr;
-  double *global_quantiles = nullptr;
-  if (tree_params.use_experimental_backend) {
-    auto quantile_size = tree_params.n_bins * ncols;
-    global_quantiles_buffer = std::make_unique<MLCommon::device_buffer<double>>(
-      handle.get_device_allocator(), handle.get_stream(), quantile_size);
-    global_quantiles = global_quantiles_buffer->data();
-    DecisionTree::computeQuantiles(global_quantiles, tree_params.n_bins, data,
-                                   nrows, ncols, handle.get_device_allocator(),
-                                   handle.get_stream());
-  }
+  auto quantile_size = tree_params.n_bins * ncols;
+  MLCommon::device_buffer<double> global_quantiles_buffer(
+    handle.get_device_allocator(), handle.get_stream(), quantile_size);
+  DecisionTree::computeQuantiles(
+    global_quantiles_buffer.data(), tree_params.n_bins, data, nrows, ncols,
+    handle.get_device_allocator(), handle.get_stream());
   dt_classifier->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows,
-                     unique_labels, tree, tree_params, seed, global_quantiles);
+                     unique_labels, tree, tree_params, seed,
+                     global_quantiles_buffer.data());
 }
 
 void decisionTreeClassifierPredict(const raft::handle_t &handle,
@@ -234,21 +182,14 @@ void decisionTreeRegressorFit(const raft::handle_t &handle,
                               uint64_t seed) {
   std::shared_ptr<DecisionTreeRegressor<float>> dt_regressor =
     std::make_shared<DecisionTreeRegressor<float>>();
-  std::unique_ptr<MLCommon::device_buffer<float>> global_quantiles_buffer =
-    nullptr;
-  float *global_quantiles = nullptr;
-
-  if (tree_params.use_experimental_backend) {
-    auto quantile_size = tree_params.n_bins * ncols;
-    global_quantiles_buffer = std::make_unique<MLCommon::device_buffer<float>>(
-      handle.get_device_allocator(), handle.get_stream(), quantile_size);
-    global_quantiles = global_quantiles_buffer->data();
-    DecisionTree::computeQuantiles(global_quantiles, tree_params.n_bins, data,
-                                   nrows, ncols, handle.get_device_allocator(),
-                                   handle.get_stream());
-  }
+  auto quantile_size = tree_params.n_bins * ncols;
+  MLCommon::device_buffer<float> global_quantiles(
+    handle.get_device_allocator(), handle.get_stream(), quantile_size);
+  DecisionTree::computeQuantiles(
+    global_quantiles.data(), tree_params.n_bins, data, nrows, ncols,
+    handle.get_device_allocator(), handle.get_stream());
   dt_regressor->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows,
-                    tree, tree_params, seed, global_quantiles);
+                    tree, tree_params, seed, global_quantiles.data());
 }
 
@@ -259,21 +200,15 @@ void decisionTreeRegressorFit(const raft::handle_t &handle,
                               uint64_t seed) {
   std::shared_ptr<DecisionTreeRegressor<double>> dt_regressor =
     std::make_shared<DecisionTreeRegressor<double>>();
-  std::unique_ptr<MLCommon::device_buffer<double>> global_quantiles_buffer =
-    nullptr;
-  double *global_quantiles = nullptr;
-  if (tree_params.use_experimental_backend) {
-    auto quantile_size = tree_params.n_bins * ncols;
-    global_quantiles_buffer = std::make_unique<MLCommon::device_buffer<double>>(
-      handle.get_device_allocator(), handle.get_stream(), quantile_size);
-    global_quantiles = global_quantiles_buffer->data();
-    DecisionTree::computeQuantiles(global_quantiles, tree_params.n_bins, data,
-                                   nrows, ncols, handle.get_device_allocator(),
-                                   handle.get_stream());
-  }
+  auto quantile_size = tree_params.n_bins * ncols;
+  MLCommon::device_buffer<double> global_quantiles(
+    handle.get_device_allocator(), handle.get_stream(), quantile_size);
+  DecisionTree::computeQuantiles(
+    global_quantiles.data(), tree_params.n_bins, data, nrows, ncols,
+    handle.get_device_allocator(), handle.get_stream());
   dt_regressor->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows,
-                    tree, tree_params, seed, global_quantiles);
+                    tree, tree_params, seed, global_quantiles.data());
 }
 
 void decisionTreeRegressorPredict(const raft::handle_t &handle,
diff --git a/cpp/src/decisiontree/decisiontree_impl.cuh b/cpp/src/decisiontree/decisiontree_impl.cuh
index e60e488c7f..8efb4d10cf 100644
--- a/cpp/src/decisiontree/decisiontree_impl.cuh
+++ b/cpp/src/decisiontree/decisiontree_impl.cuh
@@ -28,10 +28,6 @@
 #include
 #include "batched-levelalgo/builder.cuh"
 #include "decisiontree_impl.h"
-#include "levelalgo/levelfunc_classifier.cuh"
-#include "levelalgo/levelfunc_regressor.cuh"
-#include "levelalgo/metric.cuh"
-#include "memory.cuh"
 #include "quantile/quantile.h"
 #include "treelite_util.h"
@@ -252,63 +248,6 @@ void DecisionTreeBase<T, L>::print(
   get_node_text("", sparsetree, 0, false);
 }
 
-/**
- * @brief This function calls the relevant regression oir classification with input parameters.
- * @tparam T: datatype of input data (float ot double)
- * @tparam L: data type for labels (int type for classification, T type for regression).
- * @param[out] sparsetree: This will be the generated Decision Tree
- * @param[in] data: Input data
- * @param[in] ncols: Original number of columns in the dataset
- * @param[in] nrows: Original number of rows in dataset
- * @param[in] labels: Labels of input dataset
- * @param[in] rowids: List of selected rows for the tree building
- * @param[in] n_sampled_rows: Number of rows after subsampling
- * @param[in] unique_labels: Number of unique classes for calssification. Its set to 1 for regression
- * @param[in] treeid: Tree id in case of building multiple tree from RF.
- */
-template <typename T, typename L>
-void DecisionTreeBase<T, L>::plant(
-  std::vector<SparseTreeNode<T, L>> &sparsetree, const T *data, const int ncols,
-  const int nrows, const L *labels, unsigned int *rowids,
-  const int n_sampled_rows, int unique_labels, const int treeid,
-  uint64_t seed) {
-  ML::PUSH_RANGE("DecisionTreeBase::plant @decisiontree_impl.cuh");
-  dinfo.NLocalrows = nrows;
-  dinfo.NGlobalrows = nrows;
-  dinfo.Ncols = ncols;
-  n_unique_labels = unique_labels;
-
-  if (tree_params.split_algo == SPLIT_ALGO::GLOBAL_QUANTILE &&
-      tree_params.quantile_per_tree) {
-    preprocess_quantile(data, rowids, n_sampled_rows, ncols, dinfo.NLocalrows,
-                        tree_params.n_bins, tempmem);
-  }
-  CUDA_CHECK(cudaStreamSynchronize(
-    tempmem->stream));  // added to ensure accurate measurement
-  ML::PUSH_RANGE("DecisionTreeBase::plant::bootstrapping features");
-  //Bootstrap features
-  unsigned int *h_colids = tempmem->h_colids->data();
-  // fill with ascending range of indices
-  std::iota(h_colids, h_colids + dinfo.Ncols, 0);
-  // if feature sampling, shuffle
-  if (tree_params.max_features != 1.f) {
-    // seed with treeid
-    srand(treeid * 1000);
-    std::random_shuffle(h_colids, h_colids + dinfo.Ncols,
-                        [](int j) { return rand() % j; });
-  }
-  ML::POP_RANGE();
-  prepare_time = prepare_fit_timer.getElapsedSeconds();
-
-  total_temp_mem = tempmem->totalmem;
-  MLCommon::TimerCPU timer;
-  grow_deep_tree(data, labels, rowids, n_sampled_rows, ncols,
-                 tree_params.max_features, dinfo.NLocalrows, sparsetree, treeid,
-                 tempmem);
-  train_time = timer.getElapsedSeconds();
-  ML::POP_RANGE();
-}
-
 template <typename T, typename L>
 void DecisionTreeBase<T, L>::predict(const raft::handle_t &handle,
                                      const TreeMetaDataNode<T, L> *tree,
@@ -377,14 +316,13 @@ void DecisionTreeBase<T, L>::base_fit(
   const cudaStream_t stream_in, const T *data, const int ncols, const int nrows,
   const L *labels, unsigned int *rowids, const int n_sampled_rows,
   int unique_labels, std::vector<SparseTreeNode<T, L>> &sparsetree,
-  const int treeid, uint64_t seed, bool is_classifier, T *d_global_quantiles,
-  std::shared_ptr<TemporaryMemory<T, L>> in_tempmem) {
+  const int treeid, uint64_t seed, bool is_classifier, T *d_global_quantiles) {
   prepare_fit_timer.reset();
   const char *CRITERION_NAME[] = {"GINI", "ENTROPY", "MSE", "MAE", "END"};
   CRITERION default_criterion =
     (is_classifier) ? CRITERION::GINI : CRITERION::MSE;
   CRITERION last_criterion =
-    (is_classifier) ? CRITERION::ENTROPY : CRITERION::MAE;
+    (is_classifier) ? CRITERION::ENTROPY : CRITERION::MSE;
 
   validity_check(tree_params);
   if (tree_params.n_bins > n_sampled_rows) {
@@ -404,41 +342,14 @@ void DecisionTreeBase<T, L>::base_fit(
          "Unsupported criterion %s\n",
          CRITERION_NAME[tree_params.split_criterion]);
 
-  if (!tree_params.use_experimental_backend) {
-    // Only execute for level backend as temporary memory is unused in batched
-    // backend.
-    if (in_tempmem != nullptr) {
-      tempmem = in_tempmem;
-    } else {
-      tempmem = std::make_shared<TemporaryMemory<T, L>>(
-        device_allocator_in, host_allocator_in, stream_in, nrows, ncols,
-        unique_labels, tree_params);
-      tree_params.quantile_per_tree = true;
-    }
-  }
-
-  if (tree_params.use_experimental_backend) {
-    dinfo.NLocalrows = nrows;
-    dinfo.NGlobalrows = nrows;
-    dinfo.Ncols = ncols;
-    n_unique_labels = unique_labels;
-    grow_tree(device_allocator_in, host_allocator_in, data, treeid, seed, ncols,
-              nrows, labels, d_global_quantiles, (int *)rowids, n_sampled_rows,
-              unique_labels, tree_params, stream_in, sparsetree,
-              this->leaf_counter, this->depth_counter);
-  } else {
-    if (treeid == 0) {
-      CUML_LOG_WARN(
-        "The old backend is deprecated and will be removed in 21.08 "
-        "release.\n");
-      CUML_LOG_WARN("Using old backend for growing trees\n");
-    }
-    plant(sparsetree, data, ncols, nrows, labels, rowids, n_sampled_rows,
-          unique_labels, treeid, seed);
-    if (in_tempmem == nullptr) {
-      tempmem.reset();
-    }
-  }
+  dinfo.NLocalrows = nrows;
+  dinfo.NGlobalrows = nrows;
+  dinfo.Ncols = ncols;
+  n_unique_labels = unique_labels;
+  grow_tree(device_allocator_in, host_allocator_in, data, treeid, seed, ncols,
+            nrows, labels, d_global_quantiles, (int *)rowids, n_sampled_rows,
+            unique_labels, tree_params, stream_in, sparsetree,
+            this->leaf_counter, this->depth_counter);
 }
 
 template <typename T>
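With MAE rejected outright, the window of accepted criteria collapses to [GINI, ENTROPY] for classification and just MSE for regression. A reconstruction of the check base_fit performs, pieced together from the context lines above (default_criterion and last_criterion are the locals defined in the hunk; ASSERT is cuML's macro):

    // CRITERION_END means "use the task default" and is remapped first.
    if (tree_params.split_criterion == CRITERION::CRITERION_END) {
      tree_params.split_criterion = default_criterion;
    }
    // Classification accepts GINI..ENTROPY; regression accepts only MSE.
    ASSERT((tree_params.split_criterion >= default_criterion) &&
             (tree_params.split_criterion <= last_criterion),
           "Unsupported criterion %s\n",
           CRITERION_NAME[tree_params.split_criterion]);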
@@ -446,13 +357,12 @@ void DecisionTreeClassifier<T>::fit(
   const raft::handle_t &handle, const T *data, const int ncols,
   const int nrows, const int *labels, unsigned int *rowids,
   const int n_sampled_rows, const int unique_labels,
   TreeMetaDataNode<T, int> *&tree,
-  DecisionTreeParams tree_parameters, uint64_t seed, T *d_global_quantiles,
-  std::shared_ptr<TemporaryMemory<T, int>> in_tempmem) {
+  DecisionTreeParams tree_parameters, uint64_t seed, T *d_global_quantiles) {
   this->tree_params = tree_parameters;
   this->base_fit(handle.get_device_allocator(), handle.get_host_allocator(),
                  handle.get_stream(), data, ncols, nrows, labels, rowids,
                  n_sampled_rows, unique_labels, tree->sparsetree, tree->treeid,
-                 seed, true, d_global_quantiles, in_tempmem);
+                 seed, true, d_global_quantiles);
   this->set_metadata(tree);
 }
 
@@ -464,28 +374,28 @@ void DecisionTreeClassifier<T>::fit(
   const cudaStream_t stream_in, const T *data, const int ncols,
   const int nrows, const int *labels, unsigned int *rowids,
   const int n_sampled_rows, const int unique_labels,
   TreeMetaDataNode<T, int> *&tree,
-  DecisionTreeParams tree_parameters, uint64_t seed, T *d_global_quantiles,
-  std::shared_ptr<TemporaryMemory<T, int>> in_tempmem) {
+  DecisionTreeParams tree_parameters, uint64_t seed, T *d_global_quantiles) {
   this->tree_params = tree_parameters;
   this->base_fit(device_allocator_in, host_allocator_in, stream_in, data,
                  ncols, nrows, labels, rowids, n_sampled_rows, unique_labels,
-                 tree->sparsetree, tree->treeid, seed, true, d_global_quantiles,
-                 in_tempmem);
+                 tree->sparsetree, tree->treeid, seed, true,
+                 d_global_quantiles);
   this->set_metadata(tree);
 }
 
 template <typename T>
-void DecisionTreeRegressor<T>::fit(
-  const raft::handle_t &handle, const T *data, const int ncols, const int nrows,
-  const T *labels, unsigned int *rowids, const int n_sampled_rows,
-  TreeMetaDataNode<T, T> *&tree, DecisionTreeParams tree_parameters,
-  uint64_t seed, T *d_global_quantiles,
-  std::shared_ptr<TemporaryMemory<T, T>> in_tempmem) {
+void DecisionTreeRegressor<T>::fit(const raft::handle_t &handle, const T *data,
+                                   const int ncols, const int nrows,
+                                   const T *labels, unsigned int *rowids,
+                                   const int n_sampled_rows,
+                                   TreeMetaDataNode<T, T> *&tree,
+                                   DecisionTreeParams tree_parameters,
+                                   uint64_t seed, T *d_global_quantiles) {
   this->tree_params = tree_parameters;
   this->base_fit(handle.get_device_allocator(), handle.get_host_allocator(),
                  handle.get_stream(), data, ncols, nrows, labels, rowids,
                  n_sampled_rows, 1, tree->sparsetree, tree->treeid, seed, false,
-                 d_global_quantiles, in_tempmem);
+                 d_global_quantiles);
   this->set_metadata(tree);
 }
 
@@ -496,52 +406,14 @@ void DecisionTreeRegressor<T>::fit(
   const cudaStream_t stream_in, const T *data, const int ncols, const int nrows,
   const T *labels, unsigned int *rowids, const int n_sampled_rows,
   TreeMetaDataNode<T, T> *&tree, DecisionTreeParams tree_parameters,
-  uint64_t seed, T *d_global_quantiles,
-  std::shared_ptr<TemporaryMemory<T, T>> in_tempmem) {
+  uint64_t seed, T *d_global_quantiles) {
   this->tree_params = tree_parameters;
   this->base_fit(device_allocator_in, host_allocator_in, stream_in, data, ncols,
                  nrows, labels, rowids, n_sampled_rows, 1, tree->sparsetree,
-                 tree->treeid, seed, false, d_global_quantiles, in_tempmem);
+                 tree->treeid, seed, false, d_global_quantiles);
   this->set_metadata(tree);
 }
 
-template <typename T>
-void DecisionTreeClassifier<T>::grow_deep_tree(
-  const T *data, const int *labels, unsigned int *rowids,
-  const int n_sampled_rows, const int ncols, const float colper,
-  const int nrows, std::vector<SparseTreeNode<T, int>> &sparsetree,
-  const int treeid, std::shared_ptr<TemporaryMemory<T, int>> tempmem) {
-  ML::PUSH_RANGE(
-    "DecisionTreeClassifier::grow_deep_tree @decisiontree_impl.cuh");
-  int leaf_cnt = 0;
-  int depth_cnt = 0;
-  grow_deep_tree_classification(data, labels, rowids, ncols, colper,
-                                n_sampled_rows, nrows, this->n_unique_labels,
-                                this->tree_params, depth_cnt, leaf_cnt,
-                                sparsetree, treeid, tempmem);
-  this->depth_counter = depth_cnt;
-  this->leaf_counter = leaf_cnt;
-  ML::POP_RANGE();
-}
-
-template <typename T>
-void DecisionTreeRegressor<T>::grow_deep_tree(
-  const T *data, const T *labels, unsigned int *rowids,
-  const int n_sampled_rows, const int ncols, const float colper,
-  const int nrows, std::vector<SparseTreeNode<T, T>> &sparsetree,
-  const int treeid, std::shared_ptr<TemporaryMemory<T, T>> tempmem) {
-  ML::PUSH_RANGE(
-    "DecisionTreeRegressor::grow_deep_tree @decisiontree_impl.cuh");
-  int leaf_cnt = 0;
-  int depth_cnt = 0;
-  grow_deep_tree_regression(data, labels, rowids, ncols, colper, n_sampled_rows,
-                            nrows, this->tree_params, depth_cnt, leaf_cnt,
-                            sparsetree, treeid, tempmem);
-  this->depth_counter = depth_cnt;
-  this->leaf_counter = leaf_cnt;
-  ML::POP_RANGE();
-}
-
 //Class specializations
 template class DecisionTreeBase<float, int>;
 template class DecisionTreeBase<double, int>;
diff --git a/cpp/src/decisiontree/decisiontree_impl.h b/cpp/src/decisiontree/decisiontree_impl.h
index ed583b9f1a..40624221e8 100644
--- a/cpp/src/decisiontree/decisiontree_impl.h
+++ b/cpp/src/decisiontree/decisiontree_impl.h
@@ -79,7 +79,6 @@ class DecisionTreeBase {
   DataInfo dinfo;
   int depth_counter = 0;
   int leaf_counter = 0;
-  std::shared_ptr<TemporaryMemory<T, L>> tempmem;
   size_t total_temp_mem;
   const int MAXSTREAMS = 1;
   size_t max_shared_mem;
@@ -90,17 +89,6 @@ class DecisionTreeBase {
   MLCommon::TimerCPU prepare_fit_timer;
   DecisionTreeParams tree_params;
 
-  void plant(std::vector<SparseTreeNode<T, L>> &sparsetree, const T *data,
-             const int ncols, const int nrows, const L *labels,
-             unsigned int *rowids, const int n_sampled_rows, int unique_labels,
-             const int treeid, uint64_t seed);
-
-  virtual void grow_deep_tree(
-    const T *data, const L *labels, unsigned int *rowids,
-    const int n_sampled_rows, const int ncols, const float colper,
-    const int nrows, std::vector<SparseTreeNode<T, L>> &sparsetree,
-    const int treeid, std::shared_ptr<TemporaryMemory<T, L>> tempmem) = 0;
-
   void base_fit(
     const std::shared_ptr<MLCommon::deviceAllocator> device_allocator_in,
     const std::shared_ptr<MLCommon::hostAllocator> host_allocator_in,
@@ -108,8 +96,7 @@ class DecisionTreeBase {
     const int nrows, const L *labels, unsigned int *rowids,
     const int n_sampled_rows, int unique_labels,
     std::vector<SparseTreeNode<T, L>> &sparsetree, const int treeid,
-    uint64_t seed, bool is_classifier, T *d_global_quantiles,
-    std::shared_ptr<TemporaryMemory<T, L>> in_tempmem);
+    uint64_t seed, bool is_classifier, T *d_global_quantiles);
 
  public:
   // Printing utility for high level tree info.
@@ -142,8 +129,7 @@ class DecisionTreeClassifier : public DecisionTreeBase<T, int> {
            const int nrows, const int *labels, unsigned int *rowids,
            const int n_sampled_rows, const int unique_labels,
            TreeMetaDataNode<T, int> *&tree, DecisionTreeParams tree_parameters,
-           uint64_t seed, T *d_quantiles,
-           std::shared_ptr<TemporaryMemory<T, int>> in_tempmem = nullptr);
+           uint64_t seed, T *d_quantiles);
 
   //This fit function does not take handle, used by RF
   void fit(
@@ -153,17 +139,7 @@ class DecisionTreeClassifier : public DecisionTreeBase<T, int> {
     const int nrows, const int *labels, unsigned int *rowids,
     const int n_sampled_rows, const int unique_labels,
     TreeMetaDataNode<T, int> *&tree, DecisionTreeParams tree_parameters,
-    uint64_t seed, T *d_quantiles,
-    std::shared_ptr<TemporaryMemory<T, int>> in_tempmem);
-
- private:
-  void grow_deep_tree(const T *data, const int *labels, unsigned int *rowids,
-                      const int n_sampled_rows, const int ncols,
-                      const float colper, const int nrows,
-                      std::vector<SparseTreeNode<T, int>> &sparsetree,
-                      const int treeid,
-                      std::shared_ptr<TemporaryMemory<T, int>> tempmem);
-
+    uint64_t seed, T *d_quantiles);
 };  // End DecisionTreeClassifier Class
 
 template <typename T>
@@ -172,8 +148,7 @@ class DecisionTreeRegressor : public DecisionTreeBase<T, T> {
   void fit(const raft::handle_t &handle, const T *data, const int ncols,
            const int nrows, const T *labels, unsigned int *rowids,
            const int n_sampled_rows, TreeMetaDataNode<T, T> *&tree,
-           DecisionTreeParams tree_parameters, uint64_t seed, T *d_quantiles,
-           std::shared_ptr<TemporaryMemory<T, T>> in_tempmem = nullptr);
+           DecisionTreeParams tree_parameters, uint64_t seed, T *d_quantiles);
 
   //This fit function does not take handle. Used by RF
   void fit(
@@ -182,16 +157,7 @@ class DecisionTreeRegressor : public DecisionTreeBase<T, T> {
     const cudaStream_t stream_in, const T *data, const int ncols,
     const int nrows, const T *labels, unsigned int *rowids,
     const int n_sampled_rows, TreeMetaDataNode<T, T> *&tree,
-    DecisionTreeParams tree_parameters, uint64_t seed, T *d_quantiles,
-    std::shared_ptr<TemporaryMemory<T, T>> in_tempmem);
-
- private:
-  void grow_deep_tree(const T *data, const T *labels, unsigned int *rowids,
-                      const int n_sampled_rows, const int ncols,
-                      const float colper, const int nrows,
-                      std::vector<SparseTreeNode<T, T>> &sparsetree,
-                      const int treeid,
-                      std::shared_ptr<TemporaryMemory<T, T>> tempmem);
+    DecisionTreeParams tree_parameters, uint64_t seed, T *d_quantiles);
 };  // End DecisionTreeRegressor Class
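With tempmem gone from the header, the only device state a caller threads through fit() is the quantile buffer. A minimal sketch of driving the classifier directly (internal API; names follow the declarations above, and it assumes the caller sits inside namespace ML::DecisionTree with handle, device pointers, and tree_params already prepared):

    DecisionTreeClassifier<float> clf;
    TreeMetaDataNode<float, int> tree_storage;
    TreeMetaDataNode<float, int> *tree = &tree_storage;
    // d_quantiles must hold n_bins cut points per column (see decisiontree.cu).
    clf.fit(handle, d_data, ncols, nrows, d_labels, d_rowids, n_sampled_rows,
            n_unique_labels, tree, tree_params, seed, d_quantiles);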
diff --git a/cpp/src/decisiontree/levelalgo/common_helper.cuh b/cpp/src/decisiontree/levelalgo/common_helper.cuh
deleted file mode 100644
index 6795e00980..0000000000
--- a/cpp/src/decisiontree/levelalgo/common_helper.cuh
+++ /dev/null
@@ -1,334 +0,0 @@
-/*
- * Copyright (c) 2019-2021, NVIDIA CORPORATION.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include
-#include
-#include
-#include
-#include
-#include "common_kernel.cuh"
-
-#include
-
-namespace ML {
-namespace DecisionTree {
-
-/*This functions does feature subsampling.
- *The default is reshuffling of a feature list at ever level followed by random start index in the reshuffled vector for each node.
- *In case full reshuffle is enabled. A reshuffle is performed for every node in the tree
- */
-template <typename T, typename L, typename RNG, typename DIST>
-void update_feature_sampling(unsigned int *h_colids, unsigned int *d_colids,
-                             unsigned int *h_colstart, unsigned int *d_colstart,
-                             const int Ncols, const int ncols_sampled,
-                             const int n_nodes, RNG rng, DIST dist,
-                             std::vector<unsigned int> &feature_selector,
-                             std::shared_ptr<TemporaryMemory<T, L>> tempmem,
-                             raft::random::Rng &d_rng) {
-  ML::PUSH_RANGE(
-    "update_feature_sampling @common_helper.cuh (does feature subsampling)");
-  if (h_colstart != nullptr) {
-    if (Ncols != ncols_sampled) {
-      std::shuffle(h_colids, h_colids + Ncols, rng);
-      raft::update_device(d_colids, h_colids, Ncols, tempmem->stream);
-      if (n_nodes < 256 * tempmem->num_sms) {
-        for (int i = 0; i < n_nodes; i++) {
-          h_colstart[i] = dist(rng);
-        }
-        raft::update_device(d_colstart, h_colstart, n_nodes, tempmem->stream);
-      } else {
-        d_rng.uniformInt(d_colstart, n_nodes, 0, Ncols,
-                         tempmem->stream);
-        raft::update_host(h_colstart, d_colstart, n_nodes, tempmem->stream);
-      }
-    }
-  } else {
-    for (int i = 0; i < n_nodes; i++) {
-      std::vector<unsigned int> temp(feature_selector);
-      std::shuffle(temp.begin(), temp.end(), rng);
-      memcpy(&h_colids[i * ncols_sampled], temp.data(),
-             ncols_sampled * sizeof(unsigned int));
-    }
-    raft::update_device(d_colids, h_colids, ncols_sampled * n_nodes,
-                        tempmem->stream);
-  }
-  ML::POP_RANGE();
-}
-
-//This function calcualtes min/max from the samples that belong in a given node. This is done for all the nodes at a given level
-template <typename T>
-void get_minmax(const T *data, const unsigned int *flags,
-                const unsigned int *colids, const unsigned int *colstart,
-                const int nrows, const int Ncols, const int ncols_sampled,
-                const int n_nodes, const int max_shmem_nodes, T *d_minmax,
-                T *h_minmax, cudaStream_t &stream) {
-  using E = typename MLCommon::Stats::encode_traits<T>::E;
-  T init_val = std::numeric_limits<T>::max();
-  int threads = 128;
-  int nblocks = raft::ceildiv(2 * ncols_sampled * n_nodes, threads);
-  minmax_init_kernel<T, E><<<nblocks, threads, 0, stream>>>(
-    d_minmax, ncols_sampled * n_nodes, n_nodes, init_val);
-  CUDA_CHECK(cudaGetLastError());
-
-  nblocks = raft::ceildiv(nrows, threads);
-  if (n_nodes <= max_shmem_nodes) {
-    get_minmax_kernel<T, E>
-      <<<nblocks, threads, 2 * n_nodes * sizeof(T), stream>>>(
-        data, flags, colids, colstart, nrows, Ncols, ncols_sampled, n_nodes,
-        init_val, d_minmax);
-  } else {
-    get_minmax_kernel_global<T, E><<<nblocks, threads, 0, stream>>>(
-      data, flags, colids, colstart, nrows, Ncols, ncols_sampled, n_nodes,
-      d_minmax);
-  }
-  CUDA_CHECK(cudaGetLastError());
-
-  nblocks = raft::ceildiv(2 * ncols_sampled * n_nodes, threads);
-  minmax_decode_kernel<T, E>
-    <<<nblocks, threads, 0, stream>>>(d_minmax, ncols_sampled * n_nodes);
-
-  CUDA_CHECK(cudaGetLastError());
-  raft::update_host(h_minmax, d_minmax, 2 * n_nodes * ncols_sampled, stream);
-}
-
-// This function does setup for flags. and count.
-void setup_sampling(unsigned int *flagsptr, unsigned int *sample_cnt,
-                    const unsigned int *rowids, const int nrows,
-                    const int n_sampled_rows, cudaStream_t &stream) {
-  ML::PUSH_RANGE("DecisionTree::setup_sampling @common_helper.cuh");
-  CUDA_CHECK(cudaMemsetAsync(sample_cnt, 0, nrows * sizeof(int), stream));
-  int threads = 256;
-  int blocks = raft::ceildiv(n_sampled_rows, threads);
-  setup_counts_kernel<<<blocks, threads, 0, stream>>>(sample_cnt, rowids,
-                                                      n_sampled_rows);
-  CUDA_CHECK(cudaGetLastError());
-  blocks = raft::ceildiv(nrows, threads);
-  setup_flags_kernel<<<blocks, threads, 0, stream>>>(sample_cnt, flagsptr,
-                                                     nrows);
-  CUDA_CHECK(cudaGetLastError());
-  ML::POP_RANGE();  //setup_sampling @common_helper.cuh
-}
-
-//This function call the split kernel
-template <typename T, typename L>
-void make_level_split(const T *data, const int nrows, const int Ncols,
-                      const int ncols_sampled, const int nbins,
-                      const int n_nodes, const int split_algo,
-                      int *split_colidx, int *split_binidx,
-                      const unsigned int *new_node_flags, unsigned int *flags,
-                      std::shared_ptr<TemporaryMemory<T, L>> tempmem) {
-  int threads = 256;
-  int blocks = raft::ceildiv(nrows, threads);
-  unsigned int *d_colstart = nullptr;
-  if (tempmem->d_colstart != nullptr) d_colstart = tempmem->d_colstart->data();
-  if (split_algo == 0) {
-    split_level_kernel<T, MinMaxQues<T>>
-      <<<blocks, threads, 0, tempmem->stream>>>(
-        data, tempmem->d_globalminmax->data(), tempmem->d_colids->data(),
-        d_colstart, split_colidx, split_binidx, nrows, Ncols, ncols_sampled,
-        nbins, n_nodes, new_node_flags, flags);
-  } else {
-    split_level_kernel<T, QuantileQues<T>>
-      <<<blocks, threads, 0, tempmem->stream>>>(
-        data, tempmem->d_quantile->data(), tempmem->d_colids->data(),
-        d_colstart, split_colidx, split_binidx, nrows, Ncols, ncols_sampled,
-        nbins, n_nodes, new_node_flags, flags);
-  }
-  CUDA_CHECK(cudaGetLastError());
-}
-
-/* node_hist[i] holds the # times label i appear in current data. The vector is computed during gini
-   computation.
- */
-int get_class_hist(unsigned int *node_hist, const int n_unique_labels) {
-  unsigned int maxval = node_hist[0];
-  int classval = 0;
-  for (int i = 1; i < n_unique_labels; i++) {
-    if (node_hist[i] > maxval) {
-      maxval = node_hist[i];
-      classval = i;
-    }
-  }
-  return classval;
-}
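For reference, the removed get_class_hist above is a plain argmax over the per-node label histogram; the same logic in a single standard-library call (an equivalent sketch, not code from this diff — ties resolve to the first maximum in both versions):

    #include <algorithm>

    int majority_class(const unsigned int *node_hist, int n_unique_labels) {
      // Index of the first maximum count, i.e. the node's majority label.
      return static_cast<int>(
        std::max_element(node_hist, node_hist + n_unique_labels) - node_hist);
    }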
-
-template <typename T>
-T getQuesValue(const T *minmax, const T *quantile, const int nbins,
-               const int colid, const int binid, const int nodeid,
-               const int n_nodes, const int featureid, const int split_algo) {
-  if (split_algo == 0) {
-    T min = minmax[nodeid + colid * n_nodes * 2];
-    T delta = (minmax[nodeid + n_nodes + colid * n_nodes * 2] - min) / nbins;
-    return (min + delta * (binid + 1));
-  } else {
-    return quantile[featureid * nbins + binid];
-  }
-}
-
-unsigned int getQuesColumn(const unsigned int *colids, const int colstart_local,
-                           const int Ncols, const int ncols_sampled,
-                           const int colidx, const int nodeid) {
-  unsigned int col;
-  if (colstart_local != -1) {
-    col = colids[(colstart_local + colidx) % Ncols];
-  } else {
-    col = colids[nodeid * ncols_sampled + colidx];
-  }
-  return col;
-}
-
-template <typename T, typename L>
-void convert_scatter_to_gather(const unsigned int *flagsptr,
-                               const unsigned int *sample_cnt,
-                               const int n_nodes, const int n_rows,
-                               unsigned int *nodecount, unsigned int *nodestart,
-                               unsigned int *samplelist,
-                               std::shared_ptr<TemporaryMemory<T, L>> tempmem) {
-  CUDA_CHECK(cudaMemsetAsync(nodestart, 0, (n_nodes + 1) * sizeof(unsigned int),
-                             tempmem->stream));
-  CUDA_CHECK(cudaMemsetAsync(nodecount, 0, (n_nodes + 1) * sizeof(unsigned int),
-                             tempmem->stream));
-
-  int nthreads = 128;
-  int nblocks = raft::ceildiv(n_rows, nthreads);
-  fill_counts<<<nblocks, nthreads, 0, tempmem->stream>>>(flagsptr, sample_cnt,
-                                                         n_rows, nodecount);
-
-  void *d_temp_storage = (void *)(tempmem->temp_cub_buffer->data());
-  cub::DeviceScan::ExclusiveSum(d_temp_storage, tempmem->temp_cub_bytes,
-                                nodecount, nodestart, n_nodes + 1,
-                                tempmem->stream);
-  CUDA_CHECK(cudaGetLastError());
-  unsigned int *h_nodestart = (unsigned int *)(tempmem->h_split_binidx->data());
-  raft::update_host(h_nodestart, nodestart + n_nodes, 1, tempmem->stream);
-  CUDA_CHECK(cudaStreamSynchronize(tempmem->stream));
-  CUDA_CHECK(cudaMemsetAsync(nodecount, 0, n_nodes * sizeof(unsigned int),
-                             tempmem->stream));
-  CUDA_CHECK(cudaMemsetAsync(
-    samplelist, 0, h_nodestart[0] * sizeof(unsigned int), tempmem->stream));
-  build_list<<<nblocks, nthreads, 0, tempmem->stream>>>(
-    flagsptr, nodestart, n_rows, nodecount, samplelist);
-  CUDA_CHECK(cudaGetLastError());
-}
-
-template <typename T, typename L>
-void print_convertor(unsigned int *d_nodecount, unsigned int *d_nodestart,
-                     unsigned int *d_samplelist, int n_nodes,
-                     std::shared_ptr<TemporaryMemory<T, L>> tempmem) {
-  unsigned int *nodecount = (unsigned int *)(tempmem->h_split_colidx->data());
-  unsigned int *nodestart = (unsigned int *)(tempmem->h_split_binidx->data());
-  unsigned int *samplelist = (unsigned int *)(tempmem->h_parent_metric->data());
-  raft::update_host(nodecount, d_nodecount, n_nodes + 1, tempmem->stream);
-  raft::update_host(nodestart, d_nodestart, n_nodes + 1, tempmem->stream);
-  CUDA_CHECK(cudaDeviceSynchronize());
-  ML::PatternSetter _("%v");
-  CUML_LOG_DEBUG("Full sample list size %u", nodestart[n_nodes]);
-  raft::update_host(samplelist, d_samplelist, nodestart[n_nodes],
-                    tempmem->stream);
-  CUDA_CHECK(cudaDeviceSynchronize());
-
-  {
-    std::stringstream ss;
-    ss << "Printing node count\n";
-    for (int i = 0; i < n_nodes + 1; i++) {
-      ss << nodecount[i] << " ";
-    }
-    CUML_LOG_DEBUG(ss.str().c_str());
-  }
-  {
-    std::stringstream ss;
-    ss << "Printing node start\n";
-    for (int i = 0; i < n_nodes + 1; i++) {
-      ss << nodestart[i] << " ";
-    }
-    CUML_LOG_DEBUG(ss.str().c_str());
-  }
-  {
-    std::stringstream ss;
-    ss << "Printing sample list\n";
-    for (int i = 0; i < n_nodes; i++) {
-      ss << "Node id " << i << " --> ";
-      for (int j = nodestart[i]; j < nodestart[i + 1]; j++) {
-        ss << samplelist[j] << " ";
-      }
-    }
-    CUML_LOG_DEBUG(ss.str().c_str());
-  }
-}
-
-template <typename T, typename L>
-void print_nodes(SparseTreeNode<T, L> *sparsenodes, float *gain, int *nodelist,
-                 int n_nodes, std::shared_ptr<TemporaryMemory<T, L>> tempmem) {
-  CUDA_CHECK(cudaDeviceSynchronize());
-  ML::PatternSetter _("%v");
-  CUML_LOG_DEBUG(
-    "Node format --> (colid, quesval, best_metric, prediction, left_child) ");
-  int *h_nodelist = (int *)(tempmem->h_outgain->data());
-  if (nodelist != nullptr) {
-    raft::update_host(h_nodelist, nodelist, n_nodes, tempmem->stream);
-    CUDA_CHECK(cudaDeviceSynchronize());
-  }
-  for (int i = 0; i < n_nodes; i++) {
-    int nodeid = i;
-    if (nodelist != nullptr) nodeid = h_nodelist[i];
-    SparseTreeNode<T, L> &node = sparsenodes[nodeid];
-    std::stringstream ss;
-    ss << "Node id " << i << " --> (" << node.colid << " ," << node.quesval
-       << " ," << node.best_metric_val << ", ";
-    ss << node.prediction << " ," << node.left_child_id << " )";
-    if (gain != nullptr) ss << " gain -->" << gain[i];
-    CUML_LOG_DEBUG(ss.str().c_str());
-  }
-}
-
-template <typename T, typename L>
-void make_split_gather(const T *data, unsigned int *nodestart,
-                       unsigned int *samplelist, const int n_nodes,
-                       const int nrows, const int *nodelist, int *new_nodelist,
-                       unsigned int *nodecount, int *counter,
-                       unsigned int *flagsptr,
-                       const SparseTreeNode<T, L> *d_sparsenodes,
-                       std::shared_ptr<TemporaryMemory<T, L>> tempmem) {
-  CUDA_CHECK(cudaMemsetAsync(
-    nodecount, 0, (2 * n_nodes + 1) * sizeof(unsigned int), tempmem->stream));
-  CUDA_CHECK(
-    cudaMemsetAsync(counter, 0, sizeof(unsigned int), tempmem->stream));
-  CUDA_CHECK(cudaMemsetAsync(flagsptr, LEAF, nrows * sizeof(unsigned int),
-                             tempmem->stream));
-  int nthreads = 128;
-  int nblocks = raft::ceildiv(nrows, nthreads);
-  split_nodes_compute_counts_kernel<<<n_nodes, nthreads,
-                                      sizeof(SparseTreeNode<T, L>),
-                                      tempmem->stream>>>(
-    data, d_sparsenodes, nodestart, samplelist, nrows, nodelist, new_nodelist,
-    nodecount, counter, flagsptr);
-  CUDA_CHECK(cudaGetLastError());
-  void *d_temp_storage = (void *)(tempmem->temp_cub_buffer->data());
-  int *h_counter = tempmem->h_counter->data();
-  raft::update_host(h_counter, counter, 1, tempmem->stream);
-  CUDA_CHECK(cudaStreamSynchronize(tempmem->stream));
-  cub::DeviceScan::ExclusiveSum(d_temp_storage, tempmem->temp_cub_bytes,
-                                nodecount, nodestart, h_counter[0] + 1,
-                                tempmem->stream);
-  CUDA_CHECK(cudaGetLastError());
-  CUDA_CHECK(cudaMemsetAsync(samplelist, 0, h_counter[0] * sizeof(unsigned int),
-                             tempmem->stream));
-  CUDA_CHECK(cudaMemsetAsync(nodecount, 0, h_counter[0] * sizeof(unsigned int),
-                             tempmem->stream));
-  build_list<<<nblocks, nthreads, 0, tempmem->stream>>>(
-    flagsptr, nodestart, nrows, nodecount, samplelist);
-  CUDA_CHECK(cudaGetLastError());
-}
-
-}  // namespace DecisionTree
-}  // namespace ML
diff --git a/cpp/src/decisiontree/levelalgo/common_kernel.cuh b/cpp/src/decisiontree/levelalgo/common_kernel.cuh
deleted file mode 100644
index 20b1abeda6..0000000000
--- a/cpp/src/decisiontree/levelalgo/common_kernel.cuh
+++ /dev/null
@@ -1,378 +0,0 @@
-/*
- * Copyright (c) 2019-2020, NVIDIA CORPORATION.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#pragma once
-#include
-#define LEAF 0xFFFFFFFF
-#define PUSHRIGHT 0x00000001
-#include
-
-namespace ML {
-namespace DecisionTree {
-
-template <typename T>
-DI T get_data(const T* __restrict__ data, const T local_data,
-              const unsigned int dataid, const unsigned int count) {
-  if (count <= blockDim.x) {
-    return local_data;
-  } else {
-    return data[dataid];
-  }
-}
-
-DI unsigned int get_samplelist(const unsigned int* __restrict__ samplelist,
-                               const unsigned int dataid,
-                               const unsigned int nodestart, const int tid,
-                               const unsigned int count) {
-  if (count <= blockDim.x) {
-    return dataid;
-  } else {
-    return samplelist[nodestart + tid];
-  }
-}
-
-template <typename L>
-DI L get_label(const L* __restrict__ labels, const L local_label,
-               const unsigned int dataid, const unsigned int count) {
-  if (count <= blockDim.x) {
-    return local_label;
-  } else {
-    return labels[dataid];
-  }
-}
-DI int get_class_hist_shared(unsigned int* node_hist,
-                             const int n_unique_labels) {
-  unsigned int maxval = node_hist[0];
-  int classval = 0;
-  for (int i = 1; i < n_unique_labels; i++) {
-    if (node_hist[i] > maxval) {
-      maxval = node_hist[i];
-      classval = i;
-    }
-  }
-  return classval;
-}
-__global__ void fill_all_leaf(unsigned int* flags, const int nrows) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  if (tid < nrows) {
-    flags[tid] = LEAF;
-  }
-}
-DI unsigned int get_column_id(const unsigned int* __restrict__ colids,
-                              const int& colstart_local, const int& Ncols,
-                              const int& ncols_sampled,
-                              const unsigned int& colcnt,
-                              const unsigned int& local_flag) {
-  unsigned int col;
-  if (colstart_local != -1) {
-    col = colids[(colstart_local + colcnt) % Ncols];
-  } else {
-    col = colids[local_flag * ncols_sampled + colcnt];
-  }
-  return col;
-}
-template <typename T, typename E>
-__global__ void minmax_init_kernel(T* minmax, const int len, const int n_nodes,
-                                   const T init_val) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  if (tid < 2 * len) {
-    bool ifmin = (((int)(tid / n_nodes) % 2) == 0);
-    *(E*)&minmax[tid] = (ifmin) ? MLCommon::Stats::encode(init_val)
-                                : MLCommon::Stats::encode(-init_val);
-  }
-}
-
-template <typename T, typename E>
-__global__ void minmax_decode_kernel(T* minmax, const int len) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  if (tid < 2 * len) {
-    minmax[tid] = MLCommon::Stats::decode(*(E*)&minmax[tid]);
-  }
-}
-
-//This kernel calculates minmax at node level
-template <typename T, typename E>
-__global__ void get_minmax_kernel(const T* __restrict__ data,
-                                  const unsigned int* __restrict__ flags,
-                                  const unsigned int* __restrict__ colids,
-                                  const unsigned int* __restrict__ colstart,
-                                  const int nrows, const int Ncols,
-                                  const int ncols_sampled, const int n_nodes,
-                                  T init_min_val, T* minmax) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  unsigned int local_flag = LEAF;
-  int colstart_local = -1;
-  extern __shared__ char shared_mem_minmax[];
-  T* shmem_minmax = (T*)shared_mem_minmax;
-  if (tid < nrows) {
-    local_flag = flags[tid];
-  }
-  if (local_flag != LEAF && colstart != nullptr) {
-    colstart_local = colstart[local_flag];
-  }
-  for (int colcnt = 0; colcnt < ncols_sampled; colcnt++) {
-    for (int i = threadIdx.x; i < 2 * n_nodes; i += blockDim.x) {
-      *(E*)&shmem_minmax[i] = (i < n_nodes)
-                                ? MLCommon::Stats::encode(init_min_val)
-                                : MLCommon::Stats::encode(-init_min_val);
-    }
-
-    __syncthreads();
-    if (local_flag != LEAF) {
-      int col = get_column_id(colids, colstart_local, Ncols, ncols_sampled,
-                              colcnt, local_flag);
-      T local_data = data[col * nrows + tid];
-      if (!isnan(local_data)) {
-        //Min max values are saved in shared memory and global memory as per the shuffled colids.
-        MLCommon::Stats::atomicMinBits<T, E>(&shmem_minmax[local_flag],
-                                             local_data);
-        MLCommon::Stats::atomicMaxBits<T, E>(
-          &shmem_minmax[local_flag + n_nodes], local_data);
-      }
-    }
-    __syncthreads();
-
-    //finally, perform global mem atomics
-    for (int i = threadIdx.x; i < n_nodes; i += blockDim.x) {
-      atomicMin((E*)&minmax[i + 2 * n_nodes * colcnt], *(E*)&shmem_minmax[i]);
-      atomicMax((E*)&minmax[i + n_nodes + 2 * n_nodes * colcnt],
-                *(E*)&shmem_minmax[i + n_nodes]);
-    }
-    __syncthreads();
-  }
-}
-
-template <typename T, typename E>
-__global__ void get_minmax_kernel_global(
-  const T* __restrict__ data, const unsigned int* __restrict__ flags,
-  const unsigned int* __restrict__ colids,
-  const unsigned int* __restrict__ colstart, const int nrows, const int Ncols,
-  const int ncols_sampled, const int n_nodes, T* minmax) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  unsigned int local_flag = LEAF;
-  if (tid < nrows) {
-    local_flag = flags[tid];
-    if (local_flag != LEAF) {
-      int colstart_local = -1;
-      if (colstart != nullptr) colstart_local = colstart[local_flag];
-      for (int colcnt = 0; colcnt < ncols_sampled; colcnt++) {
-        int coloff = 2 * n_nodes * colcnt;
-        int col = get_column_id(colids, colstart_local, Ncols, ncols_sampled,
-                                colcnt, local_flag);
-        T local_data = data[col * nrows + tid];
-        if (!isnan(local_data)) {
-          //Min max values are saved in shared memory and global memory as per the shuffled colids.
-          MLCommon::Stats::atomicMinBits<T, E>(&minmax[coloff + local_flag],
-                                               local_data);
-          MLCommon::Stats::atomicMaxBits<T, E>(
-            &minmax[coloff + n_nodes + local_flag], local_data);
-        }
-      }
-    }
-  }
-}
-//Setup how many times a sample is being used.
-//This is due to bootstrap nature of Random Forest.
-__global__ void setup_counts_kernel(unsigned int* sample_cnt,
-                                    const unsigned int* __restrict__ rowids,
-                                    const int n_sampled_rows) {
-  int threadid = threadIdx.x + blockIdx.x * blockDim.x;
-  for (int tid = threadid; tid < n_sampled_rows;
-       tid += blockDim.x * gridDim.x) {
-    unsigned int stid = rowids[tid];
-    raft::myAtomicAdd(&sample_cnt[stid], 1);
-  }
-}
-//This initializes the flags to 0x00000000. IF a sample is not used at all we Leaf out.
-__global__ void setup_flags_kernel(const unsigned int* __restrict__ sample_cnt,
-                                   unsigned int* flags, const int nrows) {
-  int threadid = threadIdx.x + blockIdx.x * blockDim.x;
-  for (int tid = threadid; tid < nrows; tid += blockDim.x * gridDim.x) {
-    unsigned int local_cnt = sample_cnt[tid];
-    unsigned int local_flag = LEAF;
-    if (local_cnt != 0) local_flag = 0x00000000;
-    flags[tid] = local_flag;
-  }
-}
-
-// This make actual split. A split is done using bits.
-//Least significant Bit 0 means left and 1 means right.
-//As a result a max depth of 32 is supported for now.
-template <typename T, typename QuestionType>
-__global__ void split_level_kernel(
-  const T* __restrict__ data, const T* __restrict__ question_ptr,
-  const unsigned int* __restrict__ colids,
-  const unsigned int* __restrict__ colstart,
-  const int* __restrict__ split_col_index,
-  const int* __restrict__ split_bin_index, const int nrows, const int Ncols,
-  const int ncols_sampled, const int nbins, const int n_nodes,
-  const unsigned int* __restrict__ new_node_flags,
-  unsigned int* __restrict__ flags) {
-  unsigned int threadid = threadIdx.x + blockIdx.x * blockDim.x;
-  unsigned int local_flag = LEAF;
-
-  for (int tid = threadid; tid < nrows; tid += gridDim.x * blockDim.x) {
-    local_flag = flags[tid];
-
-    if (local_flag != LEAF) {
-      unsigned int local_leaf_flag = new_node_flags[local_flag];
-      if (local_leaf_flag != LEAF) {
-        int colidx = split_col_index[local_flag];
-        int local_colstart = -1;
-        if (colstart != nullptr) local_colstart = colstart[local_flag];
-        int colid = get_column_id(colids, local_colstart, Ncols, ncols_sampled,
-                                  colidx, local_flag);
-        QuestionType question(question_ptr, colid, colidx, n_nodes, local_flag,
-                              nbins);
-        T quesval = question(split_bin_index[local_flag]);
-        T local_data = data[colid * nrows + tid];
-        //The inverse comparision here to push right instead of left
-        if (local_data <= quesval) {
-          local_flag = local_leaf_flag << 1;
-        } else {
-          local_flag = (local_leaf_flag << 1) | PUSHRIGHT;
-        }
-      } else {
-        local_flag = LEAF;
-      }
-      flags[tid] = local_flag;
-    }
-  }
-}
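The comment above split_level_kernel states the key invariant of the deleted scatter phase: a row's flag accumulates its root-to-node path one bit per level, which is why a 32-bit flag caps the scatter algorithm at depth 32. A two-level worked example of the update rule (plain C++, mirroring the kernel's flag arithmetic):

    unsigned int flag = 0;           // row currently sits at the root
    flag = (flag << 1) | PUSHRIGHT;  // sample goes right at depth 0 -> 0b01
    flag = (flag << 1);              // sample goes left at depth 1  -> 0b10
    // flag == 2 now indexes the row's node among the four nodes at depth 2;
    // after d levels the flag holds the full d-bit path, so 32 bits = 32 levels.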
-
-struct GainIdxPair {
-  float gain;
-  int idx;
-};
-
-template <typename KeyReduceOp>
-struct ReducePair {
-  KeyReduceOp op;
-  DI ReducePair() {}
-  DI ReducePair(KeyReduceOp op) : op(op) {}
-  DI GainIdxPair operator()(const GainIdxPair& a, const GainIdxPair& b) {
-    GainIdxPair retval;
-    retval.gain = op(a.gain, b.gain);
-    if (retval.gain == a.gain) {
-      retval.idx = a.idx;
-    } else {
-      retval.idx = b.idx;
-    }
-    return retval;
-  }
-};
-
-template <typename T>
-struct QuantileQues {
-  const T* __restrict__ quantile;
-  DI QuantileQues(const T* __restrict__ quantile_ptr, const unsigned int colid,
-                  const unsigned int colcnt, const int n_nodes,
-                  const unsigned int nodeid, const int nbins)
-    : quantile(quantile_ptr + colid * nbins) {}
-
-  DI T operator()(const int binid) { return quantile[binid]; }
-};
-
-template <typename T>
-struct MinMaxQues {
-  T min, delta;
-  DI MinMaxQues(const T* __restrict__ minmax_ptr, const unsigned int colid,
-                const unsigned int colcnt, const int n_nodes,
-                const unsigned int nodeid, const int nbins) {
-    int off = colcnt * 2 * n_nodes + nodeid;
-    min = minmax_ptr[off];
-    delta = (minmax_ptr[off + n_nodes] - min) / nbins;
-  }
-
-  DI T operator()(const int binid) { return (min + (binid + 1) * delta); }
-};
-
-__global__ void fill_counts(const unsigned int* __restrict__ flagsptr,
-                            const unsigned int* __restrict__ sample_cnt,
-                            const int n_rows, unsigned int* nodecount) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  if (tid < n_rows) {
-    unsigned int nodeid = flagsptr[tid];
-    if (nodeid != LEAF) {
-      unsigned int count = sample_cnt[tid];
-      raft::myAtomicAdd(&nodecount[nodeid], count);
-    }
-  }
-}
-
-__global__ void build_list(const unsigned int* __restrict__ flagsptr,
-                           const unsigned int* __restrict__ nodestart,
-                           const int n_rows, unsigned int* nodecount,
-                           unsigned int* samplelist) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  if (tid < n_rows) {
-    unsigned int nodeid = flagsptr[tid];
-    if (nodeid != LEAF) {
-      unsigned int start = nodestart[nodeid];
-      unsigned int currcnt = atomicAdd(&nodecount[nodeid], 1);
-      samplelist[start + currcnt] = tid;
-    }
-  }
-}
-template <typename T, typename L>
-__global__ void split_nodes_compute_counts_kernel(
-  const T* __restrict__ data,
-  const SparseTreeNode<T, L>* __restrict__ d_sparsenodes,
-  const unsigned int* __restrict__ nodestart,
-  const unsigned int* __restrict__ samplelist, const int nrows,
-  const int* __restrict__ nodelist, int* new_nodelist,
-  unsigned int* samplecount, int* nodecounter, unsigned int* flagsptr) {
-  __shared__ int currcnt;
-  typedef cub::BlockReduce<int, 128> BlockReduce;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
-  extern __shared__ char shmem[];
-  SparseTreeNode<T, L>* localnode = (SparseTreeNode<T, L>*)shmem;
-  if (threadIdx.x == 0) {
-    localnode[0] = d_sparsenodes[nodelist[blockIdx.x]];
-  }
-  __syncthreads();
-  int colid = localnode->colid;
-  if (colid != -1) {
-    unsigned int nstart = nodestart[blockIdx.x];
-    unsigned int ncount = nodestart[blockIdx.x + 1] - nstart;
-    if (threadIdx.x == 0) {
-      currcnt = atomicAdd(nodecounter, 2);
-    }
-    __syncthreads();
-    if (threadIdx.x < 2) {
-      new_nodelist[currcnt + threadIdx.x] = 2 * blockIdx.x + threadIdx.x;
-    }
-    int tid_count = 0;
-    T quesval = localnode->quesval;
-    for (int tid = threadIdx.x; tid < ncount; tid += blockDim.x) {
-      unsigned int dataid = samplelist[nstart + tid];
-      if (data[colid * nrows + dataid] <= quesval) {
-        tid_count++;
-        flagsptr[dataid] = (unsigned int)(currcnt);
-      } else {
-        flagsptr[dataid] = (unsigned int)(currcnt + 1);
-      }
-    }
-    int cnt_left = BlockReduce(temp_storage).Sum(tid_count);
-    __syncthreads();
-    if (threadIdx.x == 0) {
-      samplecount[currcnt] = cnt_left;
-      samplecount[currcnt + 1] = ncount - cnt_left;
-    }
-  }
-}
-
-}  // namespace DecisionTree
-}  // namespace ML
diff --git a/cpp/src/decisiontree/levelalgo/levelfunc_classifier.cuh b/cpp/src/decisiontree/levelalgo/levelfunc_classifier.cuh
deleted file mode 100644
index d33c82184e..0000000000
--- a/cpp/src/decisiontree/levelalgo/levelfunc_classifier.cuh
+++ /dev/null
@@ -1,271 +0,0 @@
-/*
- * Copyright (c) 2019-2021, NVIDIA CORPORATION.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#pragma once
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include "common_helper.cuh"
-#include "levelhelper_classifier.cuh"
-#include "metric.cuh"
-
-#include
-
-namespace ML {
-namespace DecisionTree {
-
-/*
-This is the driver function for building classification tree
-level by level using a simple for loop.
-At each level; following steps are involved.
-1. Compute histograms for all nodes, all cols and all bins.
-2. Find best split col and bin for each node.
-3. Check info gain and then leaf out nodes as needed.
-4. make split.
-*/
-template <typename T>
-void grow_deep_tree_classification(
-  const T* data, const int* labels, unsigned int* rowids, const int Ncols,
-  const float colper, int n_sampled_rows, const int nrows,
-  const int n_unique_labels,
-  const ML::DecisionTree::DecisionTreeParams& tree_params, int& depth_cnt,
-  int& leaf_cnt, std::vector<SparseTreeNode<T, int>>& sparsetree,
-  const int treeid, std::shared_ptr<TemporaryMemory<T, int>> tempmem) {
-  ML::PUSH_RANGE(
-    "DecisionTree::grow_deep_tree_classification @levelfunc_classifier.cuh");
-  const int ncols_sampled = (int)(colper * Ncols);
-  unsigned int* flagsptr = tempmem->d_flags->data();
-  unsigned int* sample_cnt = tempmem->d_sample_cnt->data();
-  setup_sampling(flagsptr, sample_cnt, rowids, nrows, n_sampled_rows,
-                 tempmem->stream);
-  std::vector<unsigned int> histvec(n_unique_labels, 0);
-  T initial_metric;
-  if (tree_params.split_criterion == ML::CRITERION::GINI) {
-    initial_metric_classification<T, GiniFunctor>(labels, sample_cnt, nrows,
-                                                  n_unique_labels, histvec,
-                                                  initial_metric, tempmem);
-  } else {
-    initial_metric_classification<T, EntropyFunctor>(labels, sample_cnt, nrows,
-                                                     n_unique_labels, histvec,
-                                                     initial_metric, tempmem);
-  }
-  int reserve_depth = std::min(tempmem->swap_depth, tree_params.max_depth + 1);
-  size_t total_nodes = pow(2, (reserve_depth + 1)) - 1;
-
-  unsigned int* h_parent_hist = tempmem->h_parent_hist->data();
-  unsigned int* h_child_hist = tempmem->h_child_hist->data();
-  memcpy(h_parent_hist, histvec.data(), n_unique_labels * sizeof(int));
-
-  sparsetree.reserve(total_nodes);
-  SparseTreeNode<T, int> sparsenode;
-  sparsenode.best_metric_val = initial_metric;
-  sparsetree.push_back(sparsenode);
-  int sparsesize = 0;
-  int sparsesize_nextitr = 0;
-
-  int n_nodes = 1;
-  int n_nodes_nextitr = 1;
-  std::vector<int> sparse_nodelist;
-  sparse_nodelist.reserve(tempmem->max_nodes_per_level);
-  sparse_nodelist.push_back(0);
-
-  //RNG setup
-  std::mt19937 mtg(treeid * 1000);
-  raft::random::Rng d_rng(treeid * 1000);
-  std::uniform_int_distribution<unsigned int> dist(0, Ncols - 1);
-  //Setup pointers
-  unsigned int* d_histogram = tempmem->d_histogram->data();
-  unsigned int* h_histogram = tempmem->h_histogram->data();
-  int* h_split_binidx = tempmem->h_split_binidx->data();
-  int* d_split_binidx = tempmem->d_split_binidx->data();
-  int* h_split_colidx = tempmem->h_split_colidx->data();
-  int* d_split_colidx = tempmem->d_split_colidx->data();
-  unsigned int* h_new_node_flags = tempmem->h_new_node_flags->data();
-  unsigned int* d_new_node_flags = tempmem->d_new_node_flags->data();
-  unsigned int* d_colids = tempmem->d_colids->data();
-  unsigned int* h_colids = tempmem->h_colids->data();
-  unsigned int* d_colstart = nullptr;
-  unsigned int* h_colstart = nullptr;
-  if (tempmem->d_colstart != nullptr) {
-    d_colstart = tempmem->d_colstart->data();
-    h_colstart = tempmem->h_colstart->data();
-    CUDA_CHECK(cudaMemsetAsync(
-      d_colstart, 0, tempmem->max_nodes_per_level * sizeof(unsigned int),
-      tempmem->stream));
-    memset(h_colstart, 0, tempmem->max_nodes_per_level * sizeof(unsigned int));
-    raft::update_device(d_colids, h_colids, Ncols, tempmem->stream);
-  }
-  std::vector<unsigned int> feature_selector(h_colids, h_colids + Ncols);
-
-  int scatter_algo_depth =
-    std::min(tempmem->swap_depth, tree_params.max_depth + 1);
-  ML::PUSH_RANGE("scatter phase @levelfunc_classifier");
-  for (int depth = 0; (depth < scatter_algo_depth) && (n_nodes_nextitr != 0);
-       depth++) {
-    depth_cnt = depth;
-    n_nodes = n_nodes_nextitr;
-    sparsesize = sparsesize_nextitr;
-    sparsesize_nextitr = sparsetree.size();
-    ASSERT(
-      n_nodes <= tempmem->max_nodes_per_level,
-      "Max node limit reached. Requested nodes %d > %d max nodes at depth %d\n",
-      n_nodes, tempmem->max_nodes_per_level, depth);
-
-    update_feature_sampling(h_colids, d_colids, h_colstart, d_colstart, Ncols,
-                            ncols_sampled, n_nodes, mtg, dist, feature_selector,
-                            tempmem, d_rng);
-    get_histogram_classification(data, labels, flagsptr, sample_cnt, nrows,
-                                 Ncols, ncols_sampled, n_unique_labels,
-                                 tree_params.n_bins, n_nodes,
-                                 tree_params.split_algo, tempmem, d_histogram);
-
-    float* infogain = tempmem->h_outgain->data();
-    if (tree_params.split_criterion == ML::CRITERION::GINI) {
-      get_best_split_classification<T, GiniFunctor, GiniDevFunctor>(
-        h_histogram, d_histogram, h_colids, d_colids, h_colstart, d_colstart,
-        Ncols, ncols_sampled, tree_params.n_bins, n_unique_labels, n_nodes,
-        depth, tree_params.min_samples_leaf, tree_params.split_algo, infogain,
-        h_parent_hist, h_child_hist, sparsetree, sparsesize, sparse_nodelist,
-        h_split_colidx, h_split_binidx, d_split_colidx, d_split_binidx,
-        tempmem);
-    } else {
-      get_best_split_classification<T, EntropyFunctor, EntropyDevFunctor>(
-        h_histogram, d_histogram, h_colids, d_colids, h_colstart, d_colstart,
-        Ncols, ncols_sampled, tree_params.n_bins, n_unique_labels, n_nodes,
-        depth, tree_params.min_samples_leaf, tree_params.split_algo, infogain,
-        h_parent_hist, h_child_hist, sparsetree, sparsesize, sparse_nodelist,
-        h_split_colidx, h_split_binidx, d_split_colidx, d_split_binidx,
-        tempmem);
-    }
-
-    CUDA_CHECK(cudaStreamSynchronize(tempmem->stream));
-    leaf_eval_classification(
-      infogain, depth, tree_params.min_impurity_decrease, tree_params.max_depth,
-      n_unique_labels, tree_params.max_leaves, h_new_node_flags, sparsetree,
-      sparsesize, h_parent_hist, n_nodes_nextitr, sparse_nodelist, leaf_cnt);
-
-    raft::update_device(d_new_node_flags, h_new_node_flags, n_nodes,
-                        tempmem->stream);
-    make_level_split(data, nrows, Ncols, ncols_sampled, tree_params.n_bins,
-                     n_nodes, tree_params.split_algo, d_split_colidx,
-                     d_split_binidx, d_new_node_flags, flagsptr, tempmem);
-    CUDA_CHECK(cudaStreamSynchronize(tempmem->stream));
-    if (depth != (scatter_algo_depth - 1)) {
-      memcpy(h_parent_hist, h_child_hist,
-             2 * n_nodes * n_unique_labels * sizeof(unsigned int));
-    }
-  }
-  ML::POP_RANGE();  //scatter phase @levelfunc_classifier.cuh
-
-  ML::PUSH_RANGE("gather phase @levelfunc_classifier.cuh");
-  // Start of gather algorithm
-  //Convertor
-  CUML_LOG_DEBUG("begin gather ");
-  int lastsize = sparsetree.size() - sparsesize_nextitr;
-  n_nodes = n_nodes_nextitr;
-  if (n_nodes == 0) {
-    ML::POP_RANGE();  //gather phase ended
-    ML::POP_RANGE();  //grow_deep_tree_classification end
-    return;
-  }
-  unsigned int *d_nodecount, *d_samplelist, *d_nodestart;
-  SparseTreeNode<T, int>* d_sparsenodes;
-  SparseTreeNode<T, int>* h_sparsenodes;
-  int *h_nodelist, *d_nodelist, *d_new_nodelist;
-  int max_nodes = tempmem->max_nodes_per_level;
-  d_nodecount = (unsigned int*)(tempmem->d_child_best_metric->data());
-  d_nodestart = (unsigned int*)(tempmem->d_split_binidx->data());
-  d_samplelist = (unsigned int*)(tempmem->d_parent_metric->data());
-  d_nodelist = (int*)(tempmem->d_outgain->data());
-  d_new_nodelist = (int*)(tempmem->d_split_colidx->data());
-  h_nodelist = (int*)(tempmem->h_outgain->data());
-  d_sparsenodes = tempmem->d_sparsenodes->data();
-  h_sparsenodes = tempmem->h_sparsenodes->data();
-
-  int* h_counter = tempmem->h_counter->data();
-  int* d_counter = tempmem->d_counter->data();
-  memcpy(h_nodelist, sparse_nodelist.data(),
-         sizeof(int) * sparse_nodelist.size());
-  raft::update_device(d_nodelist, h_nodelist, sparse_nodelist.size(),
-                      tempmem->stream);
-  //Resize to remove trailing nodes from previous algorithm
-  sparsetree.resize(sparsetree.size() - lastsize);
-  convert_scatter_to_gather(flagsptr, sample_cnt, n_nodes, nrows, d_nodecount,
-                            d_nodestart, d_samplelist, tempmem);
-  if (tempmem->swap_depth == tree_params.max_depth) {
-    ++depth_cnt;
-  }
-  for (int depth = tempmem->swap_depth;
-       (depth < tree_params.max_depth) && (n_nodes != 0); depth++) {
-    depth_cnt = depth + 1;
-    //Algorithm starts here
-    update_feature_sampling(h_colids, d_colids, h_colstart, d_colstart, Ncols,
-                            ncols_sampled, lastsize, mtg, dist,
-                            feature_selector, tempmem, d_rng);
-
-    if (tree_params.split_criterion == ML::CRITERION::GINI) {
-      best_split_gather_classification<T, GiniFunctor, GiniDevFunctor>(
-        data, labels, d_colids, d_colstart, d_nodestart, d_samplelist, nrows,
-        Ncols, ncols_sampled, n_unique_labels, tree_params.n_bins, n_nodes,
-        tree_params.split_algo, sparsetree.size() + lastsize,
-        tree_params.min_impurity_decrease, tempmem, d_sparsenodes, d_nodelist);
-    } else {
-      best_split_gather_classification<T, EntropyFunctor, EntropyDevFunctor>(
-        data, labels, d_colids, d_colstart, d_nodestart, d_samplelist, nrows,
-        Ncols, ncols_sampled, n_unique_labels, tree_params.n_bins, n_nodes,
-        tree_params.split_algo, sparsetree.size() + lastsize,
-        tree_params.min_impurity_decrease, tempmem, d_sparsenodes, d_nodelist);
-    }
-    raft::update_host(h_sparsenodes, d_sparsenodes, lastsize, tempmem->stream);
-    //Update nodelist and split nodes
-
-    make_split_gather(data, d_nodestart, d_samplelist, n_nodes, nrows,
-                      d_nodelist, d_new_nodelist, d_nodecount, d_counter,
-                      flagsptr, d_sparsenodes, tempmem);
-    CUDA_CHECK(cudaMemcpyAsync(d_nodelist, d_new_nodelist,
-                               h_counter[0] * sizeof(int),
-                               cudaMemcpyDeviceToDevice, tempmem->stream));
-    CUDA_CHECK(cudaStreamSynchronize(tempmem->stream));
-    sparsetree.insert(sparsetree.end(), h_sparsenodes,
-                      h_sparsenodes + lastsize);
-    lastsize = 2 * n_nodes;
-    n_nodes = h_counter[0];
-  }
-  if (n_nodes != 0) {
-    if (tree_params.split_criterion == ML::CRITERION::GINI) {
-      make_leaf_gather_classification<T, GiniFunctor>(
-        labels, d_nodestart, d_samplelist, n_unique_labels, d_sparsenodes,
-        d_nodelist, n_nodes, tempmem);
-    } else {
-      make_leaf_gather_classification<T, EntropyFunctor>(
-        labels, d_nodestart, d_samplelist, n_unique_labels, d_sparsenodes,
-        d_nodelist, n_nodes, tempmem);
-    }
-    raft::update_host(h_sparsenodes, d_sparsenodes, lastsize, tempmem->stream);
-    CUDA_CHECK(cudaStreamSynchronize(tempmem->stream));
-    sparsetree.insert(sparsetree.end(), h_sparsenodes,
-                      h_sparsenodes + lastsize);
-  }
-
-  ML::POP_RANGE();  //gather phase @levelfunc_classifier.cuh
-  ML::POP_RANGE();  //grow_deep_tree_classification @levelfunc_classifier.cuh
-}
-
-}  // namespace DecisionTree
-}  // namespace ML
diff --git a/cpp/src/decisiontree/levelalgo/levelfunc_regressor.cuh b/cpp/src/decisiontree/levelalgo/levelfunc_regressor.cuh
deleted file mode 100644
index 6a91229f34..0000000000
---
a/cpp/src/decisiontree/levelalgo/levelfunc_regressor.cuh +++ /dev/null @@ -1,262 +0,0 @@ -/* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include -#include -#include -#include -#include "common_helper.cuh" -#include "levelhelper_regressor.cuh" -#include "metric.cuh" - -namespace ML { -namespace DecisionTree { - -/* -This is the driver function for building a regression tree -level by level using a simple for loop. -At each level, the following steps are involved: -1. Set up parent node mean and counts. -2. Compute means and counts for all nodes, all cols and all bins. -3. Find the best split col and bin for each node. -4. Check info gain and then leaf out nodes as needed. -5. Make the split. -*/ -template -void grow_deep_tree_regression( - const T* data, const T* labels, unsigned int* rowids, const int Ncols, - const float colper, const int n_sampled_rows, const int nrows, - const ML::DecisionTree::DecisionTreeParams& tree_params, int& depth_cnt, - int& leaf_cnt, std::vector>& sparsetree, - const int treeid, std::shared_ptr> tempmem) { - ML::PUSH_RANGE( - "DecisionTree::grow_deep_tree_regression @levelfunc_regressor.cuh"); - const int ncols_sampled = (int)(colper * Ncols); - unsigned int* flagsptr = tempmem->d_flags->data(); - unsigned int* sample_cnt = tempmem->d_sample_cnt->data(); - setup_sampling(flagsptr, sample_cnt, rowids, nrows, n_sampled_rows, - tempmem->stream); - - T mean; - T initial_metric; - unsigned int count; - if (tree_params.split_criterion == ML::CRITERION::MSE) { - initial_metric_regression(labels, sample_cnt, nrows, mean, - count, initial_metric, tempmem); - } else { - initial_metric_regression(labels, sample_cnt, nrows, mean, - count, initial_metric, tempmem); - } - int reserve_depth = std::min(tempmem->swap_depth, tree_params.max_depth + 1); - size_t total_nodes = pow(2, (reserve_depth + 1)) - 1; - - std::vector sparse_meanstate; - std::vector sparse_countstate; - sparse_meanstate.resize(total_nodes, 0.0); - sparse_countstate.resize(total_nodes, 0); - sparse_meanstate[0] = mean; - sparse_countstate[0] = count; - - sparsetree.reserve(total_nodes); - SparseTreeNode sparsenode; - sparsenode.best_metric_val = initial_metric; - sparsetree.push_back(sparsenode); - int sparsesize = 0; - int sparsesize_nextitr = 0; - - int n_nodes = 1; - int n_nodes_nextitr = 1; - std::vector sparse_nodelist; - sparse_nodelist.reserve(tempmem->max_nodes_per_level); - sparse_nodelist.push_back(0); - //RNG setup - std::mt19937 mtg(treeid * 1000); - raft::random::Rng d_rng(treeid * 1000); - std::uniform_int_distribution dist(0, Ncols - 1); - - //Setup pointers - T* d_mseout = tempmem->d_mseout->data(); - T* h_mseout = tempmem->h_mseout->data(); - T* d_predout = tempmem->d_predout->data(); - T* h_predout = tempmem->h_predout->data(); - unsigned int* h_count = tempmem->h_count->data(); - unsigned int* d_count = tempmem->d_count->data(); - int* h_split_binidx = tempmem->h_split_binidx->data(); - int* d_split_binidx =
tempmem->d_split_binidx->data(); - int* h_split_colidx = tempmem->h_split_colidx->data(); - int* d_split_colidx = tempmem->d_split_colidx->data(); - unsigned int* h_new_node_flags = tempmem->h_new_node_flags->data(); - unsigned int* d_new_node_flags = tempmem->d_new_node_flags->data(); - unsigned int* d_colids = tempmem->d_colids->data(); - unsigned int* h_colids = tempmem->h_colids->data(); - unsigned int* d_colstart = nullptr; - unsigned int* h_colstart = nullptr; - if (tempmem->d_colstart != nullptr) { - d_colstart = tempmem->d_colstart->data(); - h_colstart = tempmem->h_colstart->data(); - CUDA_CHECK(cudaMemsetAsync( - d_colstart, 0, tempmem->max_nodes_per_level * sizeof(unsigned int), - tempmem->stream)); - memset(h_colstart, 0, tempmem->max_nodes_per_level * sizeof(unsigned int)); - raft::update_device(d_colids, h_colids, Ncols, tempmem->stream); - } - std::vector feature_selector(h_colids, h_colids + Ncols); - float* infogain = tempmem->h_outgain->data(); - - int scatter_algo_depth = - std::min(tempmem->swap_depth, tree_params.max_depth + 1); - ML::PUSH_RANGE("scatter phase @levelfunc_regressor"); - for (int depth = 0; (depth < scatter_algo_depth) && (n_nodes_nextitr != 0); - depth++) { - depth_cnt = depth; - n_nodes = n_nodes_nextitr; - update_feature_sampling(h_colids, d_colids, h_colstart, d_colstart, Ncols, - ncols_sampled, n_nodes, mtg, dist, feature_selector, - tempmem, d_rng); - sparsesize = sparsesize_nextitr; - sparsesize_nextitr = sparsetree.size(); - - ASSERT( - n_nodes <= tempmem->max_nodes_per_level, - "Max node limit reached. Requested nodes %d > %d max nodes at depth %d\n", - n_nodes, tempmem->max_nodes_per_level, depth); - init_parent_value(sparse_meanstate, sparse_countstate, sparse_nodelist, - sparsesize, depth, tempmem); - - if (tree_params.split_criterion == ML::CRITERION::MSE) { - get_mse_regression_fused(data, labels, flagsptr, sample_cnt, nrows, - Ncols, ncols_sampled, tree_params.n_bins, - n_nodes, tree_params.split_algo, tempmem, - d_mseout, d_predout, d_count); - get_best_split_regression>( - h_mseout, d_mseout, h_predout, d_predout, h_count, d_count, h_colids, - d_colids, h_colstart, d_colstart, Ncols, ncols_sampled, - tree_params.n_bins, n_nodes, depth, tree_params.min_samples_leaf, - tree_params.split_algo, sparsesize, infogain, sparse_meanstate, - sparse_countstate, sparsetree, sparse_nodelist, h_split_colidx, - h_split_binidx, d_split_colidx, d_split_binidx, tempmem); - - } else { - get_mse_regression( - data, labels, flagsptr, sample_cnt, nrows, Ncols, ncols_sampled, - tree_params.n_bins, n_nodes, tree_params.split_algo, tempmem, d_mseout, - d_predout, d_count); - get_best_split_regression>( - h_mseout, d_mseout, h_predout, d_predout, h_count, d_count, h_colids, - d_colids, h_colstart, d_colstart, Ncols, ncols_sampled, - tree_params.n_bins, n_nodes, depth, tree_params.min_samples_leaf, - tree_params.split_algo, sparsesize, infogain, sparse_meanstate, - sparse_countstate, sparsetree, sparse_nodelist, h_split_colidx, - h_split_binidx, d_split_colidx, d_split_binidx, tempmem); - } - - CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); - leaf_eval_regression( - infogain, depth, tree_params.min_impurity_decrease, tree_params.max_depth, - tree_params.max_leaves, h_new_node_flags, sparsetree, sparsesize, - sparse_meanstate, n_nodes_nextitr, sparse_nodelist, leaf_cnt); - - raft::update_device(d_new_node_flags, h_new_node_flags, n_nodes, - tempmem->stream); - make_level_split(data, nrows, Ncols, ncols_sampled, tree_params.n_bins, - n_nodes, 
tree_params.split_algo, d_split_colidx, - d_split_binidx, d_new_node_flags, flagsptr, tempmem); - } - ML::POP_RANGE(); - ML::PUSH_RANGE("gather phase @levelfunc_regressor.cuh"); - // Start of gather algorithm - //Convertor - - int lastsize = sparsetree.size() - sparsesize_nextitr; - n_nodes = n_nodes_nextitr; - if (n_nodes == 0) { - ML::POP_RANGE(); // gather phase ended - ML::POP_RANGE(); // grow_deep_tree_regression end - return; - } - unsigned int *d_nodecount, *d_samplelist, *d_nodestart; - SparseTreeNode* d_sparsenodes; - SparseTreeNode* h_sparsenodes; - int *h_nodelist, *d_nodelist, *d_new_nodelist; - int max_nodes = tempmem->max_nodes_per_level; - d_nodecount = (unsigned int*)(tempmem->d_child_best_metric->data()); - d_nodestart = (unsigned int*)(tempmem->d_split_binidx->data()); - d_samplelist = (unsigned int*)(tempmem->d_parent_metric->data()); - d_nodelist = (int*)(tempmem->d_outgain->data()); - d_new_nodelist = (int*)(tempmem->d_split_colidx->data()); - h_nodelist = (int*)(tempmem->h_outgain->data()); - d_sparsenodes = tempmem->d_sparsenodes->data(); - h_sparsenodes = tempmem->h_sparsenodes->data(); - - int* h_counter = tempmem->h_counter->data(); - int* d_counter = tempmem->d_counter->data(); - memcpy(h_nodelist, sparse_nodelist.data(), - sizeof(int) * sparse_nodelist.size()); - raft::update_device(d_nodelist, h_nodelist, sparse_nodelist.size(), - tempmem->stream); - //Resize to remove trailing nodes from previous algorithm - sparsetree.resize(sparsetree.size() - lastsize); - convert_scatter_to_gather(flagsptr, sample_cnt, n_nodes, nrows, d_nodecount, - d_nodestart, d_samplelist, tempmem); - if (tempmem->swap_depth == tree_params.max_depth) { - ++depth_cnt; - } - for (int depth = tempmem->swap_depth; - (depth < tree_params.max_depth) && (n_nodes != 0); depth++) { - depth_cnt = depth + 1; - //Algorithm starts here - update_feature_sampling(h_colids, d_colids, h_colstart, d_colstart, Ncols, - ncols_sampled, lastsize, mtg, dist, - feature_selector, tempmem, d_rng); - - best_split_gather_regression( - data, labels, d_colids, d_colstart, d_nodestart, d_samplelist, nrows, - Ncols, ncols_sampled, tree_params.n_bins, n_nodes, tree_params.split_algo, - tree_params.split_criterion, sparsetree.size() + lastsize, - tree_params.min_impurity_decrease, tempmem, d_sparsenodes, d_nodelist); - - raft::update_host(h_sparsenodes, d_sparsenodes, lastsize, tempmem->stream); - //Update nodelist and split nodes - - make_split_gather(data, d_nodestart, d_samplelist, n_nodes, nrows, - d_nodelist, d_new_nodelist, d_nodecount, d_counter, - flagsptr, d_sparsenodes, tempmem); - CUDA_CHECK(cudaMemcpyAsync(d_nodelist, d_new_nodelist, - h_counter[0] * sizeof(int), - cudaMemcpyDeviceToDevice, tempmem->stream)); - CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); - sparsetree.insert(sparsetree.end(), h_sparsenodes, - h_sparsenodes + lastsize); - lastsize = 2 * n_nodes; - n_nodes = h_counter[0]; - } - if (n_nodes != 0) { - make_leaf_gather_regression(labels, d_nodestart, d_samplelist, - d_sparsenodes, d_nodelist, n_nodes, tempmem); - raft::update_host(h_sparsenodes, d_sparsenodes, lastsize, tempmem->stream); - CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); - sparsetree.insert(sparsetree.end(), h_sparsenodes, - h_sparsenodes + lastsize); - } - - ML::POP_RANGE(); // gather phase @levelfunc_regressor.cuh - - ML::POP_RANGE(); // grow_deep_tree_regression -} - -} // namespace DecisionTree -} // namespace ML diff --git a/cpp/src/decisiontree/levelalgo/levelhelper_classifier.cuh
b/cpp/src/decisiontree/levelalgo/levelhelper_classifier.cuh deleted file mode 100644 index 3cd5b762d5..0000000000 --- a/cpp/src/decisiontree/levelalgo/levelhelper_classifier.cuh +++ /dev/null @@ -1,384 +0,0 @@ -/* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include "levelkernel_classifier.cuh" - -#include - -namespace ML { -namespace DecisionTree { - -template -void initial_metric_classification( - const int *labels, unsigned int *sample_cnt, const int nrows, - const int n_unique_labels, std::vector &histvec, - T &initial_metric, std::shared_ptr> tempmem) { - ML::PUSH_RANGE( - "DecisionTree::initial_metric_classification @levelhelper_classifier.cuh"); - CUDA_CHECK(cudaMemsetAsync(tempmem->d_parent_hist->data(), 0, - n_unique_labels * sizeof(unsigned int), - tempmem->stream)); - int blocks = raft::ceildiv(nrows, 128); - sample_count_histogram_kernel<<stream>>>( - labels, sample_cnt, nrows, n_unique_labels, - (int *)tempmem->d_parent_hist->data()); - CUDA_CHECK(cudaGetLastError()); - raft::update_host(tempmem->h_parent_hist->data(), - tempmem->d_parent_hist->data(), n_unique_labels, - tempmem->stream); - CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); - histvec.assign(tempmem->h_parent_hist->data(), - tempmem->h_parent_hist->data() + n_unique_labels); - initial_metric = F::exec(histvec, nrows); - ML::POP_RANGE(); -} - -template -void get_histogram_classification( - const T *data, const int *labels, unsigned int *flags, - unsigned int *sample_cnt, const int nrows, const int Ncols, - const int ncols_sampled, const int n_unique_labels, const int nbins, - const int n_nodes, const int split_algo, - std::shared_ptr> tempmem, unsigned int *histout) { - ML::PUSH_RANGE( - "DecisionTree::get_histogram_classification @levelhelper_classifier.cuh"); - - size_t histcount = ncols_sampled * nbins * n_unique_labels * n_nodes; - CUDA_CHECK(cudaMemsetAsync(histout, 0, histcount * sizeof(unsigned int), - tempmem->stream)); - int node_batch = min(n_nodes, tempmem->max_nodes_class); - size_t shmem = nbins * n_unique_labels * sizeof(int) * node_batch; - int threads = 256; - int blocks = raft::ceildiv(nrows, threads); - unsigned int *d_colstart = nullptr; - if (tempmem->d_colstart != nullptr) d_colstart = tempmem->d_colstart->data(); - if (split_algo == 0) { - get_minmax(data, flags, tempmem->d_colids->data(), d_colstart, nrows, Ncols, - ncols_sampled, n_nodes, tempmem->max_nodes_minmax, - tempmem->d_globalminmax->data(), tempmem->h_globalminmax->data(), - tempmem->stream); - if ((n_nodes == node_batch)) { - get_hist_kernel> - <<stream>>>( - data, labels, flags, sample_cnt, tempmem->d_colids->data(), - d_colstart, nrows, Ncols, ncols_sampled, n_unique_labels, nbins, - n_nodes, tempmem->d_globalminmax->data(), histout); - } else { - get_hist_kernel_global> - <<stream>>>( - data, labels, flags, sample_cnt, tempmem->d_colids->data(), - d_colstart, nrows, Ncols, ncols_sampled, n_unique_labels, nbins, - n_nodes, 
tempmem->d_globalminmax->data(), histout); - } - - } else { - if ((n_nodes == node_batch)) { - get_hist_kernel> - <<stream>>>( - data, labels, flags, sample_cnt, tempmem->d_colids->data(), - d_colstart, nrows, Ncols, ncols_sampled, n_unique_labels, nbins, - n_nodes, tempmem->d_quantile->data(), histout); - } else { - get_hist_kernel_global> - <<stream>>>( - data, labels, flags, sample_cnt, tempmem->d_colids->data(), - d_colstart, nrows, Ncols, ncols_sampled, n_unique_labels, nbins, - n_nodes, tempmem->d_quantile->data(), histout); - } - } - CUDA_CHECK(cudaGetLastError()); - ML::POP_RANGE(); -} -template -void get_best_split_classification( - unsigned int *hist, unsigned int *d_hist, unsigned int *h_colids, - unsigned int *d_colids, unsigned int *h_colstart, unsigned int *d_colstart, - const int Ncols, const int ncols_sampled, const int nbins, - const int n_unique_labels, const int n_nodes, const int depth, - const int min_rpn, const int split_algo, float *gain, - unsigned int *h_parent_hist, unsigned int *h_child_hist, - std::vector> &sparsetree, const int sparsesize, - std::vector &sparse_nodelist, int *split_colidx, int *split_binidx, - int *d_split_colidx, int *d_split_binidx, - std::shared_ptr> tempmem) { - ML::PUSH_RANGE("get_best_split_classification @levelhelper_classifier.cuh"); - T *quantile = nullptr; - T *minmax = nullptr; - if (tempmem->h_quantile != nullptr) quantile = tempmem->h_quantile->data(); - if (tempmem->h_globalminmax != nullptr) - minmax = tempmem->h_globalminmax->data(); - if (split_algo == 0) CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); - bool use_gpu_flag = false; - size_t histcount = ncols_sampled * nbins * n_unique_labels * n_nodes; - if (n_nodes > 512) use_gpu_flag = true; - memset(gain, 0, n_nodes * sizeof(float)); - int sparsetree_sz = sparsetree.size(); - if (use_gpu_flag) { - //GPU based best split - unsigned int *d_parent_hist, *d_child_hist; - T *d_parent_metric, *d_child_best_metric; - T *h_parent_metric, *h_child_best_metric; - float *d_outgain, *h_outgain; - h_parent_metric = tempmem->h_parent_metric->data(); - h_child_best_metric = tempmem->h_child_best_metric->data(); - h_outgain = tempmem->h_outgain->data(); - - d_parent_hist = tempmem->d_parent_hist->data(); - d_child_hist = tempmem->d_child_hist->data(); - d_parent_metric = tempmem->d_parent_metric->data(); - d_child_best_metric = tempmem->d_child_best_metric->data(); - d_outgain = tempmem->d_outgain->data(); - for (int nodecnt = 0; nodecnt < n_nodes; nodecnt++) { - int sparse_nodeid = sparse_nodelist[nodecnt]; - int parentid = sparsesize + sparse_nodeid; - unsigned int *parent_hist = - &h_parent_hist[sparse_nodeid * n_unique_labels]; - h_parent_metric[nodecnt] = sparsetree[parentid].best_metric_val; - memcpy(&h_parent_hist[nodecnt * n_unique_labels], parent_hist, - n_unique_labels * sizeof(int)); - } - - raft::update_device(d_parent_hist, h_parent_hist, n_nodes * n_unique_labels, - tempmem->stream); - raft::update_device(d_parent_metric, h_parent_metric, n_nodes, - tempmem->stream); - CUDA_CHECK( - cudaMemsetAsync(d_outgain, 0, n_nodes * sizeof(float), tempmem->stream)); - CUDA_CHECK(cudaMemsetAsync(d_split_binidx, 0, n_nodes * sizeof(int), - tempmem->stream)); - CUDA_CHECK(cudaMemsetAsync(d_split_colidx, 0, n_nodes * sizeof(int), - tempmem->stream)); - - int threads = 64; - size_t shmemsz = (threads + 2) * 2 * n_unique_labels * sizeof(int); - get_best_split_classification_kernel - <<stream>>>( - d_hist, d_parent_hist, d_parent_metric, nbins, ncols_sampled, n_nodes, - n_unique_labels, 
min_rpn, d_outgain, d_split_colidx, d_split_binidx, - d_child_hist, d_child_best_metric); - CUDA_CHECK(cudaGetLastError()); - raft::update_host(h_child_hist, d_child_hist, 2 * n_nodes * n_unique_labels, - tempmem->stream); - raft::update_host(h_outgain, d_outgain, n_nodes, tempmem->stream); - raft::update_host(h_child_best_metric, d_child_best_metric, 2 * n_nodes, - tempmem->stream); - raft::update_host(split_binidx, d_split_binidx, n_nodes, tempmem->stream); - raft::update_host(split_colidx, d_split_colidx, n_nodes, tempmem->stream); - - CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); - for (int nodecnt = 0; nodecnt < n_nodes; nodecnt++) { - int sparse_nodeid = sparse_nodelist[nodecnt]; - //Sparse tree - SparseTreeNode &curr_node = - sparsetree[sparsesize + sparse_nodeid]; - int local_colstart = -1; - if (h_colstart != nullptr) local_colstart = h_colstart[nodecnt]; - curr_node.colid = - getQuesColumn(h_colids, local_colstart, Ncols, ncols_sampled, - split_colidx[nodecnt], nodecnt); - curr_node.quesval = getQuesValue( - minmax, quantile, nbins, split_colidx[nodecnt], split_binidx[nodecnt], - nodecnt, n_nodes, curr_node.colid, split_algo); - - curr_node.left_child_id = sparsetree_sz + 2 * nodecnt; - SparseTreeNode leftnode, rightnode; - leftnode.best_metric_val = h_child_best_metric[2 * nodecnt]; - rightnode.best_metric_val = h_child_best_metric[2 * nodecnt + 1]; - sparsetree.push_back(leftnode); - sparsetree.push_back(rightnode); - } - } else { - raft::update_host(hist, d_hist, histcount, tempmem->stream); - CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); - - for (int nodecnt = 0; nodecnt < n_nodes; nodecnt++) { - std::vector bestmetric(2, 0); - int nodeoffset = nodecnt * nbins * n_unique_labels; - int sparse_nodeid = sparse_nodelist[nodecnt]; - int parentid = sparsesize + sparse_nodeid; - int best_col_id = 0; - int best_bin_id = 0; - std::vector besthist_left(n_unique_labels, 0); - std::vector besthist_right(n_unique_labels, 0); - unsigned int *parent_hist = - &h_parent_hist[sparse_nodeid * n_unique_labels]; - for (int colid = 0; colid < ncols_sampled; colid++) { - int coloffset = colid * nbins * n_unique_labels * n_nodes; - for (int binid = 0; binid < nbins; binid++) { - int binoffset = binid * n_unique_labels; - int tmp_lnrows = 0; - int tmp_rnrows = 0; - std::vector tmp_histleft(n_unique_labels, 0); - std::vector tmp_histright(n_unique_labels, 0); - // Compute gini right and gini left value for each bin. 
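For clarity, here is the per-bin evaluation that the loop below performs, restated as a standalone sketch: the left histogram is read from the bin accumulator, the right histogram is the parent histogram minus the left one, and a candidate's gain is the parent impurity minus the count-weighted child impurities. Names are illustrative, not the deleted code's API:

#include <cstddef>
#include <vector>

// Gini impurity of a class histogram covering n samples.
static float gini(const std::vector<unsigned int>& hist, unsigned int n) {
  float g = 1.0f;
  for (unsigned int c : hist) {
    float p = float(c) / n;
    g -= p * p;
  }
  return g;
}

// gain = parent_impurity - (nL/n) * gini(left) - (nR/n) * gini(right)
static float split_gain(const std::vector<unsigned int>& left,
                        const std::vector<unsigned int>& parent,
                        float parent_impurity) {
  unsigned int nl = 0, np = 0;
  std::vector<unsigned int> right(parent.size());
  for (std::size_t c = 0; c < parent.size(); ++c) {
    nl += left[c];
    np += parent[c];
    right[c] = parent[c] - left[c];  // right child = parent minus left
  }
  unsigned int nr = np - nl;
  if (nl == 0 || nr == 0) return 0.0f;  // degenerate bins are skipped above
  float impurity = (float(nl) / np) * gini(left, nl) +
                   (float(nr) / np) * gini(right, nr);
  return parent_impurity - impurity;
}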
- for (int j = 0; j < n_unique_labels; j++) { - tmp_histleft[j] = hist[coloffset + binoffset + nodeoffset + j]; - tmp_histright[j] = parent_hist[j] - tmp_histleft[j]; - } - for (int j = 0; j < n_unique_labels; j++) { - tmp_lnrows += tmp_histleft[j]; - tmp_rnrows += tmp_histright[j]; - } - int totalrows = tmp_lnrows + tmp_rnrows; - if (tmp_lnrows == 0 || tmp_rnrows == 0 || totalrows < min_rpn) - continue; - - float tmp_gini_left = F::exec(tmp_histleft, tmp_lnrows); - float tmp_gini_right = F::exec(tmp_histright, tmp_rnrows); - - float max_value = F::max_val(n_unique_labels); - - ASSERT((tmp_gini_left >= 0.0f) && (tmp_gini_left <= max_value), - "gini left value %f not in [0.0, %f]", tmp_gini_left, - max_value); - ASSERT((tmp_gini_right >= 0.0f) && (tmp_gini_right <= max_value), - "gini right value %f not in [0.0, %f]", tmp_gini_right, - max_value); - - float impurity = (tmp_lnrows * 1.0f / totalrows) * tmp_gini_left + - (tmp_rnrows * 1.0f / totalrows) * tmp_gini_right; - float info_gain = sparsetree[parentid].best_metric_val - impurity; - - // Compute best information col_gain so far - if (info_gain > gain[nodecnt]) { - gain[nodecnt] = info_gain; - best_bin_id = binid; - best_col_id = colid; - besthist_left = tmp_histleft; - besthist_right = tmp_histright; - bestmetric[0] = tmp_gini_left; - bestmetric[1] = tmp_gini_right; - } - } - } - split_colidx[nodecnt] = best_col_id; - split_binidx[nodecnt] = best_bin_id; - //Sparse tree - SparseTreeNode &curr_node = - sparsetree[sparsesize + sparse_nodeid]; - int local_colstart = -1; - if (h_colstart != nullptr) local_colstart = h_colstart[nodecnt]; - curr_node.colid = - getQuesColumn(h_colids, local_colstart, Ncols, ncols_sampled, - split_colidx[nodecnt], nodecnt); - curr_node.quesval = getQuesValue( - minmax, quantile, nbins, split_colidx[nodecnt], split_binidx[nodecnt], - nodecnt, n_nodes, curr_node.colid, split_algo); - - curr_node.left_child_id = sparsetree_sz + 2 * nodecnt; - SparseTreeNode leftnode, rightnode; - leftnode.best_metric_val = bestmetric[0]; - rightnode.best_metric_val = bestmetric[1]; - sparsetree.push_back(leftnode); - sparsetree.push_back(rightnode); - memcpy(&h_child_hist[2 * nodecnt * n_unique_labels], besthist_left.data(), - n_unique_labels * sizeof(unsigned int)); - memcpy(&h_child_hist[(2 * nodecnt + 1) * n_unique_labels], - besthist_right.data(), n_unique_labels * sizeof(unsigned int)); - } - raft::update_device(d_split_binidx, split_binidx, n_nodes, tempmem->stream); - raft::update_device(d_split_colidx, split_colidx, n_nodes, tempmem->stream); - } - ML::POP_RANGE(); -} - -template -void leaf_eval_classification( - float *gain, int curr_depth, const float min_impurity_decrease, - const int max_depth, const int n_unique_labels, const int max_leaves, - unsigned int *new_node_flags, std::vector> &sparsetree, - const int sparsesize, unsigned int *sparse_hist, int &n_nodes_next, - std::vector &sparse_nodelist, int &tree_leaf_cnt) { - std::vector tmp_sparse_nodelist(sparse_nodelist); - sparse_nodelist.clear(); - - int non_leaf_counter = 0; - // decide if the "next" layer of nodes are to be forcefully marked as leaves - bool condition_global = curr_depth >= max_depth; - if (max_leaves != -1) - condition_global = condition_global || (tree_leaf_cnt >= max_leaves); - - for (int i = 0; i < tmp_sparse_nodelist.size(); i++) { - unsigned int node_flag; - int sparse_nodeid = tmp_sparse_nodelist[i]; - unsigned int *nodehist = &sparse_hist[sparse_nodeid * n_unique_labels]; - bool condition = condition_global || (gain[i] <= 
min_impurity_decrease); - if (condition) { - node_flag = 0xFFFFFFFF; - sparsetree[sparsesize + sparse_nodeid].colid = -1; - sparsetree[sparsesize + sparse_nodeid].prediction = - get_class_hist(nodehist, n_unique_labels); - } else { - sparse_nodelist.push_back(2 * i); - sparse_nodelist.push_back(2 * i + 1); - node_flag = non_leaf_counter; - non_leaf_counter++; - } - new_node_flags[i] = node_flag; - } - int nleafed = tmp_sparse_nodelist.size() - non_leaf_counter; - tree_leaf_cnt += nleafed; - n_nodes_next = 2 * non_leaf_counter; -} - -template -void best_split_gather_classification( - const T *data, const int *labels, const unsigned int *d_colids, - const unsigned int *d_colstart, const unsigned int *d_nodestart, - const unsigned int *d_samplelist, const int nrows, const int Ncols, - const int ncols_sampled, const int n_unique_labels, const int nbins, - const int n_nodes, const int split_algo, const size_t treesz, - const float min_impurity_split, - std::shared_ptr> tempmem, - SparseTreeNode *d_sparsenodes, int *d_nodelist) { - const int TPB = TemporaryMemory::gather_threads; - if (split_algo == 0) { - using E = typename MLCommon::Stats::encode_traits::E; - T init_val = std::numeric_limits::max(); - size_t shmemsz = n_unique_labels * (nbins + 1) * sizeof(int); - best_split_gather_classification_minmax_kernel - <<gather_threads, shmemsz, tempmem->stream>>>( - data, labels, d_colids, d_colstart, d_nodestart, d_samplelist, n_nodes, - n_unique_labels, nbins, nrows, Ncols, ncols_sampled, treesz, - min_impurity_split, init_val, d_sparsenodes, d_nodelist); - } else { - const T *d_question_ptr = tempmem->d_quantile->data(); - size_t shmemsz = n_unique_labels * (nbins + 1) * sizeof(int); - best_split_gather_classification_kernel, FDEV, TPB> - <<gather_threads, shmemsz, tempmem->stream>>>( - data, labels, d_colids, d_colstart, d_question_ptr, d_nodestart, - d_samplelist, n_nodes, n_unique_labels, nbins, nrows, Ncols, - ncols_sampled, treesz, min_impurity_split, d_sparsenodes, d_nodelist); - } - CUDA_CHECK(cudaGetLastError()); -} -template -void make_leaf_gather_classification( - const int *labels, const unsigned int *nodestart, - const unsigned int *samplelist, const int n_unique_labels, - SparseTreeNode *d_sparsenodes, int *nodelist, const int n_nodes, - std::shared_ptr> tempmem) { - size_t shmemsz = n_unique_labels * sizeof(int); - make_leaf_gather_classification_kernel - <<gather_threads, shmemsz, tempmem->stream>>>( - labels, nodestart, samplelist, n_unique_labels, d_sparsenodes, nodelist); - CUDA_CHECK(cudaGetLastError()); -} - -} // namespace DecisionTree -} // namespace ML diff --git a/cpp/src/decisiontree/levelalgo/levelhelper_regressor.cuh b/cpp/src/decisiontree/levelalgo/levelhelper_regressor.cuh deleted file mode 100644 index b0ee921bb4..0000000000 --- a/cpp/src/decisiontree/levelalgo/levelhelper_regressor.cuh +++ /dev/null @@ -1,539 +0,0 @@ -/* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once -#include -#include "levelkernel_regressor.cuh" - -namespace ML { -namespace DecisionTree { - -template -void initial_metric_regression(const T *labels, unsigned int *sample_cnt, - const int nrows, T &mean, unsigned int &count, - T &initial_metric, - std::shared_ptr> tempmem) { - ML::PUSH_RANGE( - "DecisionTree::initial_metric_regression @levelhelper_regressor.cuh"); - CUDA_CHECK( - cudaMemsetAsync(tempmem->d_mseout->data(), 0, sizeof(T), tempmem->stream)); - CUDA_CHECK( - cudaMemsetAsync(tempmem->d_predout->data(), 0, sizeof(T), tempmem->stream)); - CUDA_CHECK(cudaMemsetAsync(tempmem->d_count->data(), 0, sizeof(unsigned int), - tempmem->stream)); - int threads = 128; - int blocks = raft::ceildiv(nrows, threads); - - pred_kernel_level<<stream>>>( - labels, sample_cnt, nrows, tempmem->d_predout->data(), - tempmem->d_count->data()); - CUDA_CHECK(cudaGetLastError()); - mse_kernel_level<<stream>>>( - labels, sample_cnt, nrows, tempmem->d_predout->data(), - tempmem->d_count->data(), tempmem->d_mseout->data()); - CUDA_CHECK(cudaGetLastError()); - raft::update_host(tempmem->h_count->data(), tempmem->d_count->data(), 1, - tempmem->stream); - raft::update_host(tempmem->h_predout->data(), tempmem->d_predout->data(), 1, - tempmem->stream); - raft::update_host(tempmem->h_mseout->data(), tempmem->d_mseout->data(), 1, - tempmem->stream); - CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); - count = tempmem->h_count->data()[0]; - mean = tempmem->h_predout->data()[0] / count; - initial_metric = tempmem->h_mseout->data()[0] / count; - ML::POP_RANGE(); -} - -template -void get_mse_regression_fused(const T *data, const T *labels, - unsigned int *flags, unsigned int *sample_cnt, - const int nrows, const int Ncols, - const int ncols_sampled, const int nbins, - const int n_nodes, const int split_algo, - std::shared_ptr> tempmem, - T *d_mseout, T *d_predout, - unsigned int *d_count) { - ML::PUSH_RANGE( - "DecisionTree::get_mse_regression_fused @levelhelper_regressor.cuh"); - size_t predcount = ncols_sampled * nbins * n_nodes; - CUDA_CHECK( - cudaMemsetAsync(d_mseout, 0, 2 * predcount * sizeof(T), tempmem->stream)); - CUDA_CHECK( - cudaMemsetAsync(d_predout, 0, predcount * sizeof(T), tempmem->stream)); - CUDA_CHECK(cudaMemsetAsync(d_count, 0, predcount * sizeof(unsigned int), - tempmem->stream)); - - int node_batch_pred = min(n_nodes, tempmem->max_nodes_pred); - size_t shmempred = nbins * (sizeof(unsigned int) + sizeof(T)) * n_nodes; - - int threads = 256; - int blocks = raft::ceildiv(nrows, threads); - unsigned int *d_colstart = nullptr; - if (tempmem->d_colstart != nullptr) d_colstart = tempmem->d_colstart->data(); - - if (split_algo == 0) { - get_minmax(data, flags, tempmem->d_colids->data(), d_colstart, nrows, Ncols, - ncols_sampled, n_nodes, tempmem->max_nodes_minmax, - tempmem->d_globalminmax->data(), tempmem->h_globalminmax->data(), - tempmem->stream); - if ((n_nodes == node_batch_pred)) { - get_pred_kernel> - <<stream>>>( - data, labels, flags, sample_cnt, tempmem->d_colids->data(), - d_colstart, nrows, Ncols, ncols_sampled, nbins, n_nodes, - tempmem->d_globalminmax->data(), d_predout, d_count); - } else { - get_pred_kernel_global> - <<stream>>>( - data, labels, flags, sample_cnt, tempmem->d_colids->data(), - d_colstart, nrows, Ncols, ncols_sampled, nbins, n_nodes, - tempmem->d_globalminmax->data(), d_predout, d_count); - } - CUDA_CHECK(cudaGetLastError()); - } else { - if ((n_nodes == node_batch_pred)) { - get_pred_kernel> - <<stream>>>( - data, labels, flags, sample_cnt,
tempmem->d_colids->data(), - d_colstart, nrows, Ncols, ncols_sampled, nbins, n_nodes, - tempmem->d_quantile->data(), d_predout, d_count); - } else { - get_pred_kernel_global> - <<stream>>>( - data, labels, flags, sample_cnt, tempmem->d_colids->data(), - d_colstart, nrows, Ncols, ncols_sampled, nbins, n_nodes, - tempmem->d_quantile->data(), d_predout, d_count); - } - CUDA_CHECK(cudaGetLastError()); - } - ML::POP_RANGE(); -} -template -void get_mse_regression(const T *data, const T *labels, unsigned int *flags, - unsigned int *sample_cnt, const int nrows, - const int Ncols, const int ncols_sampled, - const int nbins, const int n_nodes, - const int split_algo, - std::shared_ptr> tempmem, - T *d_mseout, T *d_predout, unsigned int *d_count) { - ML::PUSH_RANGE("DecisionTree::get_mse_regression @levelhelper_regressor.cuh"); - size_t predcount = ncols_sampled * nbins * n_nodes; - CUDA_CHECK( - cudaMemsetAsync(d_mseout, 0, 2 * predcount * sizeof(T), tempmem->stream)); - CUDA_CHECK( - cudaMemsetAsync(d_predout, 0, predcount * sizeof(T), tempmem->stream)); - CUDA_CHECK(cudaMemsetAsync(d_count, 0, predcount * sizeof(unsigned int), - tempmem->stream)); - - int node_batch_pred = min(n_nodes, tempmem->max_nodes_pred); - int node_batch_mse = min(n_nodes, tempmem->max_nodes_mse); - size_t shmempred = nbins * (sizeof(unsigned int) + sizeof(T)) * n_nodes; - size_t shmemmse = shmempred + 2 * nbins * n_nodes * sizeof(T); - - int threads = 256; - int blocks = raft::ceildiv(nrows, threads); - unsigned int *d_colstart = nullptr; - if (tempmem->d_colstart != nullptr) d_colstart = tempmem->d_colstart->data(); - - if (split_algo == ML::SPLIT_ALGO::HIST) { - get_minmax(data, flags, tempmem->d_colids->data(), d_colstart, nrows, Ncols, - ncols_sampled, n_nodes, tempmem->max_nodes_minmax, - tempmem->d_globalminmax->data(), tempmem->h_globalminmax->data(), - tempmem->stream); - if ((n_nodes == node_batch_pred)) { - get_pred_kernel> - <<stream>>>( - data, labels, flags, sample_cnt, tempmem->d_colids->data(), - d_colstart, nrows, Ncols, ncols_sampled, nbins, n_nodes, - tempmem->d_globalminmax->data(), d_predout, d_count); - } else { - get_pred_kernel_global> - <<stream>>>( - data, labels, flags, sample_cnt, tempmem->d_colids->data(), - d_colstart, nrows, Ncols, ncols_sampled, nbins, n_nodes, - tempmem->d_globalminmax->data(), d_predout, d_count); - } - CUDA_CHECK(cudaGetLastError()); - if ((n_nodes == node_batch_mse)) { - get_mse_kernel> - <<stream>>>( - data, labels, flags, sample_cnt, tempmem->d_colids->data(), - d_colstart, nrows, Ncols, ncols_sampled, nbins, n_nodes, - tempmem->d_globalminmax->data(), tempmem->d_parent_pred->data(), - tempmem->d_parent_count->data(), d_predout, d_count, d_mseout); - } else { - get_mse_kernel_global> - <<stream>>>( - data, labels, flags, sample_cnt, tempmem->d_colids->data(), - d_colstart, nrows, Ncols, ncols_sampled, nbins, n_nodes, - tempmem->d_globalminmax->data(), tempmem->d_parent_pred->data(), - tempmem->d_parent_count->data(), d_predout, d_count, d_mseout); - } - CUDA_CHECK(cudaGetLastError()); - - } else { - if ((n_nodes == node_batch_pred)) { - get_pred_kernel> - <<stream>>>( - data, labels, flags, sample_cnt, tempmem->d_colids->data(), - d_colstart, nrows, Ncols, ncols_sampled, nbins, n_nodes, - tempmem->d_quantile->data(), d_predout, d_count); - } else { - get_pred_kernel_global> - <<stream>>>( - data, labels, flags, sample_cnt, tempmem->d_colids->data(), - d_colstart, nrows, Ncols, ncols_sampled, nbins, n_nodes, - tempmem->d_quantile->data(), d_predout, d_count); - } - 
CUDA_CHECK(cudaGetLastError()); - if ((n_nodes == node_batch_mse)) { - get_mse_kernel> - <<stream>>>( - data, labels, flags, sample_cnt, tempmem->d_colids->data(), - d_colstart, nrows, Ncols, ncols_sampled, nbins, n_nodes, - tempmem->d_quantile->data(), tempmem->d_parent_pred->data(), - tempmem->d_parent_count->data(), d_predout, d_count, d_mseout); - } else { - get_mse_kernel_global> - <<stream>>>( - data, labels, flags, sample_cnt, tempmem->d_colids->data(), - d_colstart, nrows, Ncols, ncols_sampled, nbins, n_nodes, - tempmem->d_quantile->data(), tempmem->d_parent_pred->data(), - tempmem->d_parent_count->data(), d_predout, d_count, d_mseout); - } - CUDA_CHECK(cudaGetLastError()); - } - ML::POP_RANGE(); -} -template -void get_best_split_regression( - T *mseout, T *d_mseout, T *predout, T *d_predout, unsigned int *count, - unsigned int *d_count, unsigned int *h_colids, unsigned int *d_colids, - unsigned int *h_colstart, unsigned int *d_colstart, const int Ncols, - const int ncols_sampled, const int nbins, const int n_nodes, const int depth, - const int min_rpn, const int split_algo, const int sparsesize, float *gain, - std::vector &sparse_meanstate, - std::vector &sparse_countstate, - std::vector> &sparsetree, - std::vector &sparse_nodelist, int *split_colidx, int *split_binidx, - int *d_split_colidx, int *d_split_binidx, - std::shared_ptr> tempmem) { - ML::PUSH_RANGE("get_best_split_regression @levelhelper_regressor.cuh"); - T *quantile = nullptr; - T *minmax = nullptr; - if (tempmem->h_quantile != nullptr) quantile = tempmem->h_quantile->data(); - if (tempmem->h_globalminmax != nullptr) - minmax = tempmem->h_globalminmax->data(); - - size_t predcount = ncols_sampled * nbins * n_nodes; - bool use_gpu_flag = false; - if (n_nodes > 512) use_gpu_flag = true; - - memset(gain, 0, n_nodes * sizeof(float)); - int sparsetree_sz = sparsetree.size(); - if (use_gpu_flag) { - int threads = 64; - - T *h_parentmetric = tempmem->h_parent_metric->data(); - float *h_outgain = tempmem->h_outgain->data(); - T *h_childmean = tempmem->h_child_pred->data(); - unsigned int *h_childcount = tempmem->h_child_count->data(); - T *h_childmetric = tempmem->h_child_best_metric->data(); - - T *d_parentmean = tempmem->d_parent_pred->data(); - unsigned int *d_parentcount = tempmem->d_parent_count->data(); - T *d_parentmetric = tempmem->d_parent_metric->data(); - float *d_outgain = tempmem->d_outgain->data(); - T *d_childmean = tempmem->d_child_pred->data(); - unsigned int *d_childcount = tempmem->d_child_count->data(); - T *d_childmetric = tempmem->d_child_best_metric->data(); - - for (int nodecnt = 0; nodecnt < n_nodes; nodecnt++) { - int sparse_nodeid = sparse_nodelist[nodecnt]; - h_parentmetric[nodecnt] = - sparsetree[sparsesize + sparse_nodeid].best_metric_val; - } - - //Here parent mean and count are already updated - raft::update_device(d_parentmetric, h_parentmetric, n_nodes, - tempmem->stream); - CUDA_CHECK( - cudaMemsetAsync(d_outgain, 0, n_nodes * sizeof(float), tempmem->stream)); - CUDA_CHECK(cudaMemsetAsync(d_split_binidx, 0, n_nodes * sizeof(int), - tempmem->stream)); - CUDA_CHECK(cudaMemsetAsync(d_split_colidx, 0, n_nodes * sizeof(int), - tempmem->stream)); - - get_best_split_regression_kernel - <<stream>>>( - d_mseout, d_predout, d_count, d_parentmean, d_parentcount, - d_parentmetric, nbins, ncols_sampled, n_nodes, min_rpn, d_outgain, - d_split_colidx, d_split_binidx, d_childmean, d_childcount, - d_childmetric); - CUDA_CHECK(cudaGetLastError()); - - raft::update_host(h_childmetric, d_childmetric, 2 * 
n_nodes, - tempmem->stream); - raft::update_host(h_outgain, d_outgain, n_nodes, tempmem->stream); - raft::update_host(h_childmean, d_childmean, 2 * n_nodes, tempmem->stream); - raft::update_host(h_childcount, d_childcount, 2 * n_nodes, tempmem->stream); - raft::update_host(split_binidx, d_split_binidx, n_nodes, tempmem->stream); - raft::update_host(split_colidx, d_split_colidx, n_nodes, tempmem->stream); - CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); - - for (int nodecnt = 0; nodecnt < n_nodes; nodecnt++) { - int sparse_nodeid = sparse_nodelist[nodecnt]; - SparseTreeNode &curr_node = sparsetree[sparsesize + sparse_nodeid]; - int local_colstart = -1; - if (h_colstart != nullptr) local_colstart = h_colstart[nodecnt]; - curr_node.colid = - getQuesColumn(h_colids, local_colstart, Ncols, ncols_sampled, - split_colidx[nodecnt], nodecnt); - curr_node.quesval = getQuesValue( - minmax, quantile, nbins, split_colidx[nodecnt], split_binidx[nodecnt], - nodecnt, n_nodes, curr_node.colid, split_algo); - - curr_node.left_child_id = sparsetree_sz + 2 * nodecnt; - sparse_meanstate[curr_node.left_child_id] = h_childmean[nodecnt * 2]; - sparse_meanstate[curr_node.left_child_id + 1] = - h_childmean[nodecnt * 2 + 1]; - sparse_countstate[curr_node.left_child_id] = h_childcount[nodecnt * 2]; - sparse_countstate[curr_node.left_child_id + 1] = - h_childcount[nodecnt * 2 + 1]; - SparseTreeNode leftnode, rightnode; - leftnode.best_metric_val = h_childmetric[nodecnt * 2]; - rightnode.best_metric_val = h_childmetric[nodecnt * 2 + 1]; - sparsetree.push_back(leftnode); - sparsetree.push_back(rightnode); - } - - } else { - raft::update_host(mseout, d_mseout, 2 * predcount, tempmem->stream); - raft::update_host(predout, d_predout, predcount, tempmem->stream); - raft::update_host(count, d_count, predcount, tempmem->stream); - CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); - for (int nodecnt = 0; nodecnt < n_nodes; nodecnt++) { - T bestmetric_left = 0; - T bestmetric_right = 0; - int nodeoff_mse = nodecnt * nbins * 2; - int nodeoff_pred = nodecnt * nbins; - int sparse_nodeid = sparse_nodelist[nodecnt]; - int parentid = sparse_nodeid + sparsesize; - int best_col_id = 0; - int best_bin_id = 0; - T bestmean_left = 0; - T bestmean_right = 0; - unsigned int bestcount_left = 0; - unsigned int bestcount_right = 0; - T parent_mean = sparse_meanstate[parentid]; - unsigned int parent_count = sparse_countstate[parentid]; - for (int colid = 0; colid < ncols_sampled; colid++) { - int coloff_mse = colid * nbins * 2 * n_nodes; - int coloff_pred = colid * nbins * n_nodes; - for (int binid = 0; binid < nbins; binid++) { - int binoff_mse = binid * 2; - int binoff_pred = binid; - unsigned int tmp_lnrows = 0; - unsigned int tmp_rnrows = 0; - - tmp_lnrows = count[coloff_pred + binoff_pred + nodeoff_pred]; - tmp_rnrows = parent_count - tmp_lnrows; - unsigned int totalrows = tmp_lnrows + tmp_rnrows; - if (tmp_lnrows == 0 || tmp_rnrows == 0 || totalrows < min_rpn) - continue; - T tmp_meanleft = predout[coloff_pred + binoff_pred + nodeoff_pred]; - T tmp_meanright = parent_mean * parent_count - tmp_meanleft; - T tmp_mse_left = mseout[coloff_mse + binoff_mse + nodeoff_mse]; - T tmp_mse_right = mseout[coloff_mse + binoff_mse + nodeoff_mse + 1]; - - float info_gain = (float)Gain::exec( - sparsetree[parentid].best_metric_val, parent_count, tmp_lnrows, - tmp_rnrows, parent_mean, tmp_meanleft, tmp_meanright, tmp_mse_left, - tmp_mse_right); - // Compute best information col_gain so far - if (info_gain > gain[nodecnt]) { - gain[nodecnt] = 
info_gain; - best_bin_id = binid; - best_col_id = colid; - bestmean_left = tmp_meanleft; - bestmean_right = tmp_meanright; - bestcount_left = tmp_lnrows; - bestcount_right = tmp_rnrows; - bestmetric_left = tmp_mse_left; - bestmetric_right = tmp_mse_right; - } - } - } - split_colidx[nodecnt] = best_col_id; - split_binidx[nodecnt] = best_bin_id; - //Sparse Tree - SparseTreeNode &curr_node = sparsetree[sparsesize + sparse_nodeid]; - int local_colstart = -1; - if (h_colstart != nullptr) local_colstart = h_colstart[nodecnt]; - curr_node.colid = - getQuesColumn(h_colids, local_colstart, Ncols, ncols_sampled, - split_colidx[nodecnt], nodecnt); - curr_node.quesval = getQuesValue( - minmax, quantile, nbins, split_colidx[nodecnt], split_binidx[nodecnt], - nodecnt, n_nodes, curr_node.colid, split_algo); - - curr_node.left_child_id = sparsetree_sz + 2 * nodecnt; - sparse_meanstate[curr_node.left_child_id] = bestmean_left; - sparse_meanstate[curr_node.left_child_id + 1] = bestmean_right; - sparse_countstate[curr_node.left_child_id] = bestcount_left; - sparse_countstate[curr_node.left_child_id + 1] = bestcount_right; - SparseTreeNode leftnode, rightnode; - leftnode.best_metric_val = bestmetric_left; - rightnode.best_metric_val = bestmetric_right; - sparsetree.push_back(leftnode); - sparsetree.push_back(rightnode); - } - raft::update_device(d_split_binidx, split_binidx, n_nodes, tempmem->stream); - raft::update_device(d_split_colidx, split_colidx, n_nodes, tempmem->stream); - } - ML::POP_RANGE(); -} - -template -void leaf_eval_regression(float *gain, int curr_depth, - const float min_impurity_decrease, - const int max_depth, const int max_leaves, - unsigned int *new_node_flags, - std::vector> &sparsetree, - const int sparsesize, std::vector &sparse_mean, - int &n_nodes_next, std::vector &sparse_nodelist, - int &tree_leaf_cnt) { - ML::PUSH_RANGE("leaf_eval_regression @levelhelper_regressor.cuh"); - std::vector tmp_sparse_nodelist(sparse_nodelist); - sparse_nodelist.clear(); - - int non_leaf_counter = 0; - bool condition_global = (curr_depth >= max_depth); - if (max_leaves != -1) - condition_global = condition_global || (tree_leaf_cnt >= max_leaves); - - for (int i = 0; i < tmp_sparse_nodelist.size(); i++) { - unsigned int node_flag; - int sparse_nodeid = tmp_sparse_nodelist[i]; - T nodemean = sparse_mean[sparsesize + sparse_nodeid]; - bool condition = condition_global || (gain[i] <= min_impurity_decrease); - if (condition) { - node_flag = 0xFFFFFFFF; - sparsetree[sparsesize + sparse_nodeid].colid = -1; - sparsetree[sparsesize + sparse_nodeid].prediction = nodemean; - } else { - sparse_nodelist.push_back(2 * i); - sparse_nodelist.push_back(2 * i + 1); - node_flag = non_leaf_counter; - non_leaf_counter++; - } - new_node_flags[i] = node_flag; - } - int nleafed = tmp_sparse_nodelist.size() - non_leaf_counter; - tree_leaf_cnt += nleafed; - n_nodes_next = 2 * non_leaf_counter; - ML::POP_RANGE(); -} - -template -void init_parent_value(std::vector &sparse_meanstate, - std::vector &sparse_countstate, - std::vector &sparse_nodelist, const int sparsesize, - const int depth, - std::shared_ptr> tempmem) { - T *h_predout = tempmem->h_predout->data(); - unsigned int *h_count = tempmem->h_count->data(); - int n_nodes = sparse_nodelist.size(); - for (int i = 0; i < n_nodes; i++) { - int sparse_nodeid = sparse_nodelist[i]; - h_predout[i] = sparse_meanstate[sparsesize + sparse_nodeid]; - h_count[i] = sparse_countstate[sparsesize + sparse_nodeid]; - } - raft::update_device(tempmem->d_parent_pred->data(), h_predout, n_nodes, 
- tempmem->stream); - raft::update_device(tempmem->d_parent_count->data(), h_count, n_nodes, - tempmem->stream); -} - -template -void best_split_gather_regression( - const T *data, const T *labels, const unsigned int *d_colids, - const unsigned int *d_colstart, const unsigned int *d_nodestart, - const unsigned int *d_samplelist, const int nrows, const int Ncols, - const int ncols_sampled, const int nbins, const int n_nodes, - const int split_algo, const ML::CRITERION split_cr, const size_t treesz, - const float min_impurity_split, - std::shared_ptr> tempmem, - SparseTreeNode *d_sparsenodes, int *d_nodelist) { - ML::PUSH_RANGE("get_best_split_gather_regression @levelhelper_regressor.cuh"); - const int TPB = TemporaryMemory::gather_threads; - if (split_cr == ML::CRITERION::MSE) { - if (split_algo == ML::SPLIT_ALGO::HIST) { - using E = typename MLCommon::Stats::encode_traits::E; - T init_val = std::numeric_limits::max(); - size_t shmemsz = nbins * sizeof(int) + nbins * sizeof(T); - best_split_gather_regression_mse_minmax_kernel - <<gather_threads, shmemsz, tempmem->stream>>>( - data, labels, d_colids, d_colstart, d_nodestart, d_samplelist, - n_nodes, nbins, nrows, Ncols, ncols_sampled, treesz, - min_impurity_split, init_val, d_sparsenodes, d_nodelist); - } else { - const T *d_question_ptr = tempmem->d_quantile->data(); - size_t shmemsz = nbins * sizeof(T) + nbins * sizeof(int); - best_split_gather_regression_mse_kernel, TPB> - <<gather_threads, shmemsz, tempmem->stream>>>( - data, labels, d_colids, d_colstart, d_question_ptr, d_nodestart, - d_samplelist, n_nodes, nbins, nrows, Ncols, ncols_sampled, treesz, - min_impurity_split, d_sparsenodes, d_nodelist); - } - CUDA_CHECK(cudaGetLastError()); - } else { - if (split_algo == 0) { - using E = typename MLCommon::Stats::encode_traits::E; - T init_val = std::numeric_limits::max(); - size_t shmemsz = 3 * nbins * sizeof(int) + nbins * sizeof(T); - best_split_gather_regression_mae_minmax_kernel - <<gather_threads, shmemsz, tempmem->stream>>>( - data, labels, d_colids, d_colstart, d_nodestart, d_samplelist, - n_nodes, nbins, nrows, Ncols, ncols_sampled, treesz, - min_impurity_split, init_val, d_sparsenodes, d_nodelist); - } else { - const T *d_question_ptr = tempmem->d_quantile->data(); - size_t shmemsz = 3 * nbins * sizeof(T) + nbins * sizeof(int); - best_split_gather_regression_mae_kernel, TPB> - <<gather_threads, shmemsz, tempmem->stream>>>( - data, labels, d_colids, d_colstart, d_question_ptr, d_nodestart, - d_samplelist, n_nodes, nbins, nrows, Ncols, ncols_sampled, treesz, - min_impurity_split, d_sparsenodes, d_nodelist); - } - CUDA_CHECK(cudaGetLastError()); - } - ML::POP_RANGE(); -} -template -void make_leaf_gather_regression( - const T *labels, const unsigned int *nodestart, - const unsigned int *samplelist, SparseTreeNode *d_sparsenodes, - int *nodelist, const int n_nodes, - std::shared_ptr> tempmem) { - make_leaf_gather_regression_kernel<<gather_threads, 0, - tempmem->stream>>>( - labels, nodestart, samplelist, d_sparsenodes, nodelist); - CUDA_CHECK(cudaGetLastError()); -} - -} // namespace DecisionTree -} // namespace ML diff --git a/cpp/src/decisiontree/levelalgo/levelkernel_classifier.cuh b/cpp/src/decisiontree/levelalgo/levelkernel_classifier.cuh deleted file mode 100644 index e783480604..0000000000 --- a/cpp/src/decisiontree/levelalgo/levelkernel_classifier.cuh +++ /dev/null @@ -1,620 +0,0 @@ -/* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. 
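The regression Gain functor used by get_best_split_regression above is also defined outside this hunk; for the MSE criterion, the conventional form is variance reduction, i.e. the parent's MSE minus the count-weighted child MSEs. A sketch under that assumption:

// Assumed variance-reduction form of Gain::exec for the MSE criterion
// (the actual functor lives outside this hunk).
template <typename T>
T mse_split_gain(T parent_mse, T mse_left, T mse_right,
                 unsigned int n_left, unsigned int n_right) {
  unsigned int n = n_left + n_right;
  if (n_left == 0 || n_right == 0) return T(0);  // such bins are skipped above
  return parent_mse - (T(n_left) / n) * mse_left - (T(n_right) / n) * mse_right;
}

A split is then kept only if this gain exceeds min_impurity_decrease, matching the leaf_eval_regression check above.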
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#include "common_kernel.cuh" - -namespace ML { -namespace DecisionTree { - -__global__ void sample_count_histogram_kernel( - const int* __restrict__ labels, const unsigned int* __restrict__ sample_cnt, - const int nrows, const int nmax, int* histout) { - int threadid = threadIdx.x + blockIdx.x * blockDim.x; - extern __shared__ unsigned int shmemhist[]; - for (int tid = threadIdx.x; tid < nmax; tid += blockDim.x) { - shmemhist[tid] = 0; - } - - __syncthreads(); - - for (int tid = threadid; tid < nrows; tid += blockDim.x * gridDim.x) { - int label = labels[tid]; - int count = sample_cnt[tid]; - raft::myAtomicAdd(&shmemhist[label], count); - } - - __syncthreads(); - - for (int tid = threadIdx.x; tid < nmax; tid += blockDim.x) { - raft::myAtomicAdd((unsigned int*)&histout[tid], - shmemhist[tid]); - } - return; -} - -//This kernel does histograms for all bins, all cols and all nodes at a given level -template -__global__ void get_hist_kernel( - const T* __restrict__ data, const int* __restrict__ labels, - const unsigned int* __restrict__ flags, - const unsigned int* __restrict__ sample_cnt, - const unsigned int* __restrict__ colids, - const unsigned int* __restrict__ colstart, const int nrows, const int Ncols, - const int ncols_sampled, const int n_unique_labels, const int nbins, - const int n_nodes, const T* __restrict__ question_ptr, - unsigned int* histout) { - extern __shared__ unsigned int shmemhist[]; - unsigned int local_flag = LEAF; - int local_label = -1; - int local_cnt; - int tid = threadIdx.x + blockIdx.x * blockDim.x; - unsigned int colid; - int colstart_local = -1; - if (tid < nrows) { - local_flag = flags[tid]; - } - if (local_flag != LEAF) { - local_label = labels[tid]; - local_cnt = sample_cnt[tid]; - if (colstart != nullptr) colstart_local = colstart[local_flag]; - } - for (unsigned int colcnt = 0; colcnt < ncols_sampled; colcnt++) { - if (local_flag != LEAF) { - colid = get_column_id(colids, colstart_local, Ncols, ncols_sampled, - colcnt, local_flag); - } - for (unsigned int i = threadIdx.x; i < nbins * n_nodes * n_unique_labels; - i += blockDim.x) { - shmemhist[i] = 0; - } - __syncthreads(); - - //Check if leaf - if (local_flag != LEAF) { - T local_data = data[tid + colid * nrows]; - QuestionType question(question_ptr, colid, colcnt, n_nodes, local_flag, - nbins); -#pragma unroll(8) - for (unsigned int binid = 0; binid < nbins; binid++) { - if (local_data <= question(binid)) { - unsigned int nodeoff = local_flag * nbins * n_unique_labels; - raft::myAtomicAdd( - &shmemhist[nodeoff + binid * n_unique_labels + local_label], - local_cnt); - } - } - } - - __syncthreads(); - for (unsigned int i = threadIdx.x; i < nbins * n_nodes * n_unique_labels; - i += blockDim.x) { - unsigned int offset = colcnt * nbins * n_nodes * n_unique_labels; - raft::myAtomicAdd(&histout[offset + i], shmemhist[i]); - } - __syncthreads(); - } -} - -/*This kernel does histograms for all bins, all cols and all nodes at a 
given level
- * when nodes cannot fit in shared memory. We use direct global atomics,
- * as this will be faster than a shared memory loop due to reduced
- * contention for atomics
- */
-template
-__global__ void get_hist_kernel_global(
-  const T* __restrict__ data, const int* __restrict__ labels,
-  const unsigned int* __restrict__ flags,
-  const unsigned int* __restrict__ sample_cnt,
-  const unsigned int* __restrict__ colids,
-  const unsigned int* __restrict__ colstart, const int nrows, const int Ncols,
-  const int ncols_sampled, const int n_unique_labels, const int nbins,
-  const int n_nodes, const T* __restrict__ question_ptr,
-  unsigned int* histout) {
-  unsigned int local_flag;
-  int local_label;
-  int local_cnt;
-  int threadid = threadIdx.x + blockIdx.x * blockDim.x;
-  for (int tid = threadid; tid < nrows; tid += gridDim.x * blockDim.x) {
-    local_flag = flags[tid];
-    if (local_flag != LEAF) {
-      local_label = labels[tid];
-      local_cnt = sample_cnt[tid];
-      int colstart_local = -1;
-      if (colstart != nullptr) colstart_local = colstart[local_flag];
-
-      for (unsigned int colcnt = 0; colcnt < ncols_sampled; colcnt++) {
-        unsigned int colid = get_column_id(colids, colstart_local, Ncols,
-                                           ncols_sampled, colcnt, local_flag);
-        T local_data = data[tid + colid * nrows];
-        //Loop over nbins
-        QuestionType question(question_ptr, colid, colcnt, n_nodes, local_flag,
-                              nbins);
-
-#pragma unroll(8)
-        for (unsigned int binid = 0; binid < nbins; binid++) {
-          if (local_data <= question(binid)) {
-            unsigned int coloff = colcnt * nbins * n_nodes * n_unique_labels;
-            unsigned int nodeoff = local_flag * nbins * n_unique_labels;
-            raft::myAtomicAdd(
-              &histout[coloff + nodeoff + binid * n_unique_labels +
-                       local_label],
-              local_cnt);
-          }
-        }
-      }
-    }
-  }
-}
-
-struct GiniDevFunctor {
-  static DI float exec(unsigned int* hist, int nrows, int n_unique_labels) {
-    float gval = 1.0;
-    for (int i = 0; i < n_unique_labels; i++) {
-      float prob = ((float)hist[i]) / nrows;
-      gval -= prob * prob;
-    }
-    return gval;
-  }
-  static DI void execshared(const unsigned int* hist, float* metric,
-                            const int nrows, const int n_unique_labels) {
-    auto& tid = threadIdx.x;
-    if (tid == 0) metric[0] = 1.0;
-    __syncthreads();
-    if (tid < n_unique_labels) {
-      float prob = ((float)hist[tid]) / nrows;
-      prob = -1 * prob * prob;
-      raft::myAtomicAdd(metric, prob);
-    }
-    __syncthreads();
-  }
-};
-
-struct EntropyDevFunctor {
-  static DI float exec(unsigned int* hist, int nrows, int n_unique_labels) {
-    float eval = 0.0;
-    for (int i = 0; i < n_unique_labels; i++) {
-      if (hist[i] != 0) {
-        float prob = ((float)hist[i]) / nrows;
-        eval += prob * logf(prob);
-      }
-    }
-    return (-1 * eval);
-  }
-  static DI void execshared(const unsigned int* hist, float* metric,
-                            const int nrows, const int n_unique_labels) {
-    auto& tid = threadIdx.x;
-    if (tid == 0) metric[0] = 0.0;
-    __syncthreads();
-    if (tid < n_unique_labels) {
-      if (hist[tid] != 0) {
-        float prob = ((float)hist[tid]) / nrows;
-        prob = -1 * prob * logf(prob);
-        raft::myAtomicAdd(metric, prob);
-      }
-    }
-    __syncthreads();
-  }
-};
-//This is the device equivalent of the best split finding reduction.
-//It only kicks in when the number of nodes exceeds 512; otherwise we use the CPU.
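As context for the block-reduce kernel that follows, here is a minimal host-side C++ sketch (an illustrative restatement, not code from the deleted file) of the impurity math that GiniDevFunctor and EntropyDevFunctor implement: Gini = 1 - sum_i p_i^2 and entropy = -sum_i p_i*log(p_i), with p_i = hist[i]/nrows.

#include <cmath>
#include <cstdio>
#include <vector>

// Host-side restatement of GiniDevFunctor::exec: 1 - sum_i p_i^2.
static float gini_impurity(const std::vector<unsigned int>& hist, int nrows) {
  float gval = 1.0f;
  for (unsigned int h : hist) {
    float prob = static_cast<float>(h) / nrows;
    gval -= prob * prob;
  }
  return gval;
}

// Host-side restatement of EntropyDevFunctor::exec: -sum_i p_i*log(p_i),
// skipping empty classes so log(0) never occurs.
static float entropy_impurity(const std::vector<unsigned int>& hist,
                              int nrows) {
  float eval = 0.0f;
  for (unsigned int h : hist) {
    if (h != 0) {
      float prob = static_cast<float>(h) / nrows;
      eval += prob * std::log(prob);
    }
  }
  return -eval;
}

int main() {
  std::vector<unsigned int> hist = {40, 10};  // 50 samples across two classes
  std::printf("gini = %f entropy = %f\n", gini_impurity(hist, 50),
              entropy_impurity(hist, 50));
  return 0;
}

For the histogram {40, 10} this prints a Gini of 0.32 and an entropy of about 0.50; these are the per-node metrics that the split search tries to reduce.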
-template -__global__ void get_best_split_classification_kernel( - const unsigned int* __restrict__ hist, - const unsigned int* __restrict__ parent_hist, - const T* __restrict__ parent_metric, const int nbins, const int ncols_sampled, - const int n_nodes, const int n_unique_labels, const int min_rpn, - float* outgain, int* best_col_id, int* best_bin_id, unsigned int* child_hist, - T* child_best_metric) { - extern __shared__ unsigned int shmem_split_eval[]; - __shared__ int best_nrows[2]; - __shared__ GainIdxPair shared_pair; - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - unsigned int* tmp_histleft = &shmem_split_eval[threadIdx.x * n_unique_labels]; - unsigned int* tmp_histright = - &shmem_split_eval[threadIdx.x * n_unique_labels + - blockDim.x * n_unique_labels]; - unsigned int* best_split_hist = - &shmem_split_eval[2 * n_unique_labels * blockDim.x]; - unsigned int* parent_hist_local = - &shmem_split_eval[2 * n_unique_labels * (blockDim.x + 1)]; - - for (unsigned int nodeid = blockIdx.x; nodeid < n_nodes; - nodeid += gridDim.x) { - if (threadIdx.x < 2) { - best_nrows[threadIdx.x] = 0; - } - - int nodeoffset = nodeid * nbins * n_unique_labels; - float parent_metric_local = parent_metric[nodeid]; - - for (int j = threadIdx.x; j < n_unique_labels; j += blockDim.x) { - parent_hist_local[j] = parent_hist[nodeid * n_unique_labels + j]; - } - - __syncthreads(); - - GainIdxPair tid_pair; - tid_pair.gain = 0.0; - tid_pair.idx = -1; - for (int id = threadIdx.x; id < nbins * ncols_sampled; id += blockDim.x) { - int coloffset = ((int)(id / nbins)) * nbins * n_unique_labels * n_nodes; - int binoffset = (id % nbins) * n_unique_labels; - int tmp_lnrows = 0; - int tmp_rnrows = 0; - for (int j = 0; j < n_unique_labels; j++) { - tmp_histleft[j] = hist[coloffset + binoffset + nodeoffset + j]; - tmp_lnrows += tmp_histleft[j]; - tmp_histright[j] = parent_hist_local[j] - tmp_histleft[j]; - tmp_rnrows += tmp_histright[j]; - } - - int totalrows = tmp_lnrows + tmp_rnrows; - if (tmp_lnrows == 0 || tmp_rnrows == 0 || totalrows < min_rpn) continue; - - float tmp_gini_left = F::exec(tmp_histleft, tmp_lnrows, n_unique_labels); - float tmp_gini_right = - F::exec(tmp_histright, tmp_rnrows, n_unique_labels); - - float impurity = (tmp_lnrows * 1.0f / totalrows) * tmp_gini_left + - (tmp_rnrows * 1.0f / totalrows) * tmp_gini_right; - float info_gain = parent_metric_local - impurity; - if (info_gain > tid_pair.gain) { - tid_pair.gain = info_gain; - tid_pair.idx = id; - } - } - __syncthreads(); - GainIdxPair ans = - BlockReduce(temp_storage).Reduce(tid_pair, ReducePair()); - - if (threadIdx.x == 0) { - shared_pair = ans; - } - __syncthreads(); - ans = shared_pair; - - if (ans.idx != -1) { - if (threadIdx.x == (blockDim.x - 1)) { - outgain[nodeid] = ans.gain; - best_col_id[nodeid] = (int)(ans.idx / nbins); - best_bin_id[nodeid] = ans.idx % nbins; - } - - int coloffset = - ((int)(ans.idx / nbins)) * nbins * n_unique_labels * n_nodes; - int binoffset = (ans.idx % nbins) * n_unique_labels; - - for (int j = threadIdx.x; j < n_unique_labels; j += blockDim.x) { - unsigned int val_left = hist[coloffset + binoffset + nodeoffset + j]; - unsigned int val_right = parent_hist_local[j] - val_left; - best_split_hist[j] = val_left; - raft::myAtomicAdd((unsigned int*)&best_nrows[0], - val_left); - best_split_hist[j + n_unique_labels] = val_right; - raft::myAtomicAdd((unsigned int*)&best_nrows[1], - val_right); - } - __syncthreads(); - - for (int j = threadIdx.x; j < 2 * n_unique_labels; 
j += blockDim.x) { - child_hist[2 * n_unique_labels * nodeid + j] = best_split_hist[j]; - } - - if (threadIdx.x < 2) { - child_best_metric[2 * nodeid + threadIdx.x] = - F::exec(&best_split_hist[threadIdx.x * n_unique_labels], - best_nrows[threadIdx.x], n_unique_labels); - } - } - } -} - -template -DI GainIdxPair bin_info_gain_classification( - const unsigned int* shmemhist_parent, const float* parent_metric, - unsigned int* shmemhist_left, const int nsamples, const int nbins, - const int n_unique_labels) { - GainIdxPair tid_pair; - tid_pair.gain = 0.0; - tid_pair.idx = -1; - for (int tid = threadIdx.x; tid < nbins; tid += blockDim.x) { - int nrows_left = 0; - unsigned int* shmemhist = shmemhist_left + tid * n_unique_labels; - for (int i = 0; i < n_unique_labels; i++) { - nrows_left += shmemhist[i]; - } - if ((nrows_left != nsamples) && (nrows_left != 0)) { - int nrows_right = nsamples - nrows_left; - float left_metric = F::exec(shmemhist, nrows_left, n_unique_labels); - for (int i = 0; i < n_unique_labels; i++) { - shmemhist[i] = shmemhist_parent[i] - shmemhist[i]; - } - float right_metric = F::exec(shmemhist, nrows_right, n_unique_labels); - float impurity = ((nrows_left * 1.0f) / nsamples) * left_metric + - ((nrows_right * 1.0f) / nsamples) * right_metric; - float info_gain = parent_metric[0] - impurity; - if (info_gain > tid_pair.gain) { - tid_pair.gain = info_gain; - tid_pair.idx = tid; - } - } - } - return tid_pair; -} - -template -__global__ void best_split_gather_classification_kernel( - const T* __restrict__ data, const int* __restrict__ labels, - const unsigned int* __restrict__ colids, - const unsigned int* __restrict__ colstart, const T* __restrict__ question_ptr, - const unsigned int* __restrict__ g_nodestart, - const unsigned int* __restrict__ samplelist, const int n_nodes, - const int n_unique_labels, const int nbins, const int nrows, const int Ncols, - const int ncols_sampled, const size_t treesz, const float min_impurity_split, - SparseTreeNode* d_sparsenodes, int* d_nodelist) { - //shmemhist_parent[n_unique_labels] - extern __shared__ unsigned int shmemhist_parent[]; - __shared__ GainIdxPair shmem_pair; - __shared__ int shmem_col; - __shared__ float parent_metric; - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - //shmemhist_left[n_unique_labels*nbins] - unsigned int* shmemhist_left = shmemhist_parent + n_unique_labels; - - int colstart_local = -1; - int colid; - int local_label; - unsigned int dataid; - unsigned int nodestart = g_nodestart[blockIdx.x]; - unsigned int count = g_nodestart[blockIdx.x + 1] - nodestart; - if (colstart != nullptr) colstart_local = colstart[blockIdx.x]; - - //Compute parent histograms - for (int i = threadIdx.x; i < n_unique_labels; i += blockDim.x) { - shmemhist_parent[i] = 0; - } - if (threadIdx.x == 0) { - shmem_pair.gain = 0.0f; - shmem_pair.idx = -1; - shmem_col = -1; - } - __syncthreads(); - for (int tid = threadIdx.x; tid < count; tid += blockDim.x) { - dataid = samplelist[nodestart + tid]; - local_label = labels[dataid]; - raft::myAtomicAdd(&shmemhist_parent[local_label], 1); - } - FDEV::execshared(shmemhist_parent, &parent_metric, count, n_unique_labels); - //Loop over cols - for (unsigned int colcnt = 0; colcnt < ncols_sampled; colcnt++) { - colid = get_column_id(colids, colstart_local, Ncols, ncols_sampled, colcnt, - blockIdx.x); - for (int i = threadIdx.x; i < nbins * n_unique_labels; i += blockDim.x) { - shmemhist_left[i] = 0; - } - QuestionType question(question_ptr, colid, 
colcnt, n_nodes, blockIdx.x, - nbins); - __syncthreads(); - for (int tid = threadIdx.x; tid < count; tid += blockDim.x) { - dataid = get_samplelist(samplelist, dataid, nodestart, tid, count); - T local_data = data[dataid + colid * nrows]; - local_label = get_label(labels, local_label, dataid, count); -#pragma unroll(8) - for (unsigned int binid = 0; binid < nbins; binid++) { - int histid = binid * n_unique_labels + local_label; - if (local_data <= question(binid)) { - raft::myAtomicAdd(&shmemhist_left[histid], 1); - } - } - } - __syncthreads(); - GainIdxPair bin_pair = bin_info_gain_classification( - shmemhist_parent, &parent_metric, shmemhist_left, count, nbins, - n_unique_labels); - GainIdxPair best_bin_pair = - BlockReduce(temp_storage).Reduce(bin_pair, ReducePair()); - __syncthreads(); - - if ((best_bin_pair.gain > shmem_pair.gain) && (threadIdx.x == 0)) { - shmem_pair = best_bin_pair; - shmem_col = colcnt; - } - } - __syncthreads(); - if (threadIdx.x == 0) { - SparseTreeNode localnode; - if ((shmem_col != -1) && (shmem_pair.gain > min_impurity_split)) { - colid = get_column_id(colids, colstart_local, Ncols, ncols_sampled, - shmem_col, blockIdx.x); - QuestionType question(question_ptr, colid, shmem_col, n_nodes, blockIdx.x, - nbins); - localnode.quesval = question(shmem_pair.idx); - localnode.left_child_id = treesz + 2 * blockIdx.x; - } else { - colid = -1; - localnode.prediction = - get_class_hist_shared(shmemhist_parent, n_unique_labels); - } - localnode.colid = colid; - localnode.best_metric_val = parent_metric; - d_sparsenodes[d_nodelist[blockIdx.x]] = localnode; - } -} - -//The same as above but fused minmax at block level -template -__global__ void best_split_gather_classification_minmax_kernel( - const T* __restrict__ data, const int* __restrict__ labels, - const unsigned int* __restrict__ colids, - const unsigned int* __restrict__ colstart, - const unsigned int* __restrict__ g_nodestart, - const unsigned int* __restrict__ samplelist, const int n_nodes, - const int n_unique_labels, const int nbins, const int nrows, const int Ncols, - const int ncols_sampled, const size_t treesz, const float min_impurity_split, - const T init_min_val, SparseTreeNode* d_sparsenodes, - int* d_nodelist) { - //shmemhist_parent[n_unique_labels] - extern __shared__ unsigned int shmemhist_parent[]; - __shared__ GainIdxPair shmem_pair; - __shared__ int shmem_col; - __shared__ float parent_metric; - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ T shmem_min, shmem_max, best_min, best_delta; - //shmemhist_left[n_unique_labels*nbins] - unsigned int* shmemhist_left = shmemhist_parent + n_unique_labels; - - int colstart_local = -1; - int colid; - int local_label; - unsigned int dataid; - T local_data; - unsigned int nodestart = g_nodestart[blockIdx.x]; - unsigned int count = g_nodestart[blockIdx.x + 1] - nodestart; - if (colstart != nullptr) colstart_local = colstart[blockIdx.x]; - - //Compute parent histograms - for (int i = threadIdx.x; i < n_unique_labels; i += blockDim.x) { - shmemhist_parent[i] = 0; - } - if (threadIdx.x == 0) { - shmem_pair.gain = 0.0f; - shmem_pair.idx = -1; - shmem_col = -1; - } - __syncthreads(); - for (int tid = threadIdx.x; tid < count; tid += blockDim.x) { - dataid = samplelist[nodestart + tid]; - local_label = labels[dataid]; - raft::myAtomicAdd(&shmemhist_parent[local_label], 1); - } - FDEV::execshared(shmemhist_parent, &parent_metric, count, n_unique_labels); - //Loop over cols - for (unsigned int colcnt = 0; 
colcnt < ncols_sampled; colcnt++) { - if (threadIdx.x == 0) { - *(E*)&shmem_min = MLCommon::Stats::encode(init_min_val); - *(E*)&shmem_max = MLCommon::Stats::encode(-init_min_val); - } - colid = get_column_id(colids, colstart_local, Ncols, ncols_sampled, colcnt, - blockIdx.x); - for (int i = threadIdx.x; i < nbins * n_unique_labels; i += blockDim.x) { - shmemhist_left[i] = 0; - } - __syncthreads(); - - //compute min/max using independent data pass - for (int tid = threadIdx.x; tid < count; tid += blockDim.x) { - unsigned int dataid = samplelist[nodestart + tid]; - local_data = data[dataid + colid * nrows]; - MLCommon::Stats::atomicMinBits(&shmem_min, local_data); - MLCommon::Stats::atomicMaxBits(&shmem_max, local_data); - } - __syncthreads(); - - T threadmin = MLCommon::Stats::decode(*(E*)&shmem_min); - T delta = - (MLCommon::Stats::decode(*(E*)&shmem_max) - threadmin) / (nbins + 1); - - for (int tid = threadIdx.x; tid < count; tid += blockDim.x) { - dataid = get_samplelist(samplelist, dataid, nodestart, tid, count); - local_data = get_data(data, local_data, dataid + colid * nrows, count); - local_label = get_label(labels, local_label, dataid, count); -#pragma unroll(8) - for (unsigned int binid = 0; binid < nbins; binid++) { - int histid = binid * n_unique_labels + local_label; - if (local_data <= threadmin + delta * (binid + 1)) { - raft::myAtomicAdd(&shmemhist_left[histid], 1); - } - } - } - __syncthreads(); - GainIdxPair bin_pair = bin_info_gain_classification( - shmemhist_parent, &parent_metric, shmemhist_left, count, nbins, - n_unique_labels); - GainIdxPair best_bin_pair = - BlockReduce(temp_storage).Reduce(bin_pair, ReducePair()); - __syncthreads(); - - if ((best_bin_pair.gain > shmem_pair.gain)) { - if (threadIdx.x == 0) { - shmem_pair = best_bin_pair; - shmem_col = colcnt; - best_min = threadmin; - best_delta = delta; - } - } - } - __syncthreads(); - if (threadIdx.x == 0) { - SparseTreeNode localnode; - if ((shmem_col != -1) && (shmem_pair.gain > min_impurity_split)) { - colid = get_column_id(colids, colstart_local, Ncols, ncols_sampled, - shmem_col, blockIdx.x); - localnode.quesval = best_min + (shmem_pair.idx + 1) * best_delta; - localnode.left_child_id = treesz + 2 * blockIdx.x; - } else { - colid = -1; - localnode.prediction = - get_class_hist_shared(shmemhist_parent, n_unique_labels); - } - localnode.colid = colid; - localnode.best_metric_val = parent_metric; - d_sparsenodes[d_nodelist[blockIdx.x]] = localnode; - } -} - -//A light weight implementation of the above kernel for last level, -// when all nodes are to be leafed out -template -__global__ void make_leaf_gather_classification_kernel( - const int* __restrict__ labels, const unsigned int* __restrict__ g_nodestart, - const unsigned int* __restrict__ samplelist, const int n_unique_labels, - SparseTreeNode* d_sparsenodes, int* d_nodelist) { - __shared__ float parent_metric; - //shmemhist_parent[n_unique_labels] - extern __shared__ unsigned int shmemhist_parent[]; - unsigned int nodestart = g_nodestart[blockIdx.x]; - unsigned int count = g_nodestart[blockIdx.x + 1] - nodestart; - - //Compute parent histograms - for (int i = threadIdx.x; i < n_unique_labels; i += blockDim.x) { - shmemhist_parent[i] = 0; - } - __syncthreads(); - for (int tid = threadIdx.x; tid < count; tid += blockDim.x) { - unsigned int dataid = samplelist[nodestart + tid]; - int local_label = labels[dataid]; - raft::myAtomicAdd(&shmemhist_parent[local_label], 1); - } - FDEV::execshared(shmemhist_parent, &parent_metric, count, n_unique_labels); - 
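-  // parent_metric now holds this node's impurity (Gini or entropy), reduced
-  // cooperatively by FDEV::execshared from the shared class histogram;
-  // thread 0 below emits the leaf with the majority-class prediction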
__syncthreads(); - if (threadIdx.x == 0) { - SparseTreeNode localnode; - localnode.prediction = - get_class_hist_shared(shmemhist_parent, n_unique_labels); - localnode.colid = -1; - localnode.best_metric_val = parent_metric; - d_sparsenodes[d_nodelist[blockIdx.x]] = localnode; - } -} - -} // namespace DecisionTree -} // namespace ML diff --git a/cpp/src/decisiontree/levelalgo/levelkernel_regressor.cuh b/cpp/src/decisiontree/levelalgo/levelkernel_regressor.cuh deleted file mode 100644 index 7e863791dc..0000000000 --- a/cpp/src/decisiontree/levelalgo/levelkernel_regressor.cuh +++ /dev/null @@ -1,1056 +0,0 @@ -/* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#include "common_kernel.cuh" - -namespace ML { -namespace DecisionTree { - -template -struct MSEGain { - static HDI T exec(const T parent_best_metric, const unsigned int total, - const unsigned int left, const unsigned int right, - const T parent_mean, T &mean_left, T &mean_right, - T &mse_left, T &mse_right) { - mean_right /= right; - mean_left /= left; - mse_left = mean_left; - mse_right = mean_right; - T left_impurity = ((float)left / total) * mean_left * mean_left; - T right_impurity = ((float)right / total) * mean_right * mean_right; - T temp = left_impurity + right_impurity - (parent_mean * parent_mean); - return temp; - } - static HDI T exec(const unsigned int total, const unsigned int left, - const unsigned int right, const T parent_mean, T &mean_left, - T &mean_right) { - mean_right /= right; - mean_left /= left; - T left_impurity = ((float)left / total) * mean_left * mean_left; - T right_impurity = ((float)right / total) * mean_right * mean_right; - T temp = left_impurity + right_impurity - (parent_mean * parent_mean); - return temp; - } -}; - -template -struct MAEGain { - static HDI T exec(const T parent_best_metric, const unsigned int total, - const unsigned int left, const unsigned int right, - const T parent_mean, T &mean_left, T &mean_right, - T &mae_left, T &mae_right) { - mean_left /= left; - mean_right /= right; - mae_left /= left; - mae_right /= right; - T left_impurity = (left * 1.0 / total) * mae_left; - T right_impurity = (right * 1.0 / total) * mae_right; - return (parent_best_metric - (left_impurity + right_impurity)); - } - static HDI T exec(const T parent_mae, const T mae_left, const T mae_right, - const unsigned int left, const unsigned int right, - const unsigned int total) { - T left_impurity = (left * 1.0 / total) * mae_left; - T right_impurity = (right * 1.0 / total) * mae_right; - return (parent_mae - (left_impurity + right_impurity)); - } -}; - -template -__global__ void pred_kernel_level(const T *__restrict__ labels, - const unsigned int *__restrict__ sample_cnt, - const int nrows, T *predout, - unsigned int *countout) { - int threadid = threadIdx.x + blockIdx.x * blockDim.x; - __shared__ T shmempred; - __shared__ unsigned int shmemcnt; - if (threadIdx.x == 0) { - shmempred = 0; - shmemcnt = 0; - } - 
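-  // thread 0 zeroed the block-shared accumulators above; the barrier below
-  // publishes them to the whole block before the grid-stride accumulation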
__syncthreads();
-
-  for (int tid = threadid; tid < nrows; tid += blockDim.x * gridDim.x) {
-    T label = labels[tid];
-    unsigned int count = sample_cnt[tid];
-    raft::myAtomicAdd(&shmemcnt, count);
-    raft::myAtomicAdd(&shmempred, label * count);
-  }
-  __syncthreads();
-
-  if (threadIdx.x == 0) {
-    raft::myAtomicAdd(predout, shmempred);
-    raft::myAtomicAdd(countout, shmemcnt);
-  }
-  return;
-}
-
-template
-__global__ void mse_kernel_level(const T *__restrict__ labels,
-                                 const unsigned int *__restrict__ sample_cnt,
-                                 const int nrows, const T *predout,
-                                 const unsigned int *count, T *mseout) {
-  int threadid = threadIdx.x + blockIdx.x * blockDim.x;
-  __shared__ T shmemmse;
-  if (threadIdx.x == 0) shmemmse = 0;
-  __syncthreads();
-
-  T mean = predout[0] / count[0];
-  for (int tid = threadid; tid < nrows; tid += blockDim.x * gridDim.x) {
-    T label = labels[tid];
-    unsigned int local_count = sample_cnt[tid];
-    T value = F::exec(label - mean);
-    raft::myAtomicAdd(&shmemmse, local_count * value);
-  }
-
-  __syncthreads();
-
-  if (threadIdx.x == 0) {
-    raft::myAtomicAdd(mseout, shmemmse);
-  }
-  return;
-}
-//This kernel computes predictions and count for all cols, all bins and all nodes at a given level
-template
-__global__ void get_pred_kernel(
-  const T *__restrict__ data, const T *__restrict__ labels,
-  const unsigned int *__restrict__ flags,
-  const unsigned int *__restrict__ sample_cnt,
-  const unsigned int *__restrict__ colids,
-  const unsigned int *__restrict__ colstart, const int nrows, const int Ncols,
-  const int ncols_sampled, const int nbins, const int n_nodes,
-  const T *__restrict__ question_ptr, T *predout, unsigned int *countout) {
-  extern __shared__ char shmem_pred_kernel[];
-  T *shmempred = (T *)shmem_pred_kernel;
-  unsigned int *shmemcount =
-    (unsigned int *)(&shmem_pred_kernel[nbins * n_nodes * sizeof(T)]);
-  unsigned int local_flag = LEAF;
-  T local_label;
-  int local_cnt;
-  int colstart_local = -1;
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  unsigned int colid;
-  if (tid < nrows) {
-    local_flag = flags[tid];
-  }
-  if (local_flag != LEAF) {
-    local_label = labels[tid];
-    local_cnt = sample_cnt[tid];
-    if (colstart != nullptr) colstart_local = colstart[local_flag];
-  }
-  for (unsigned int colcnt = 0; colcnt < ncols_sampled; colcnt++) {
-    if (local_flag != LEAF) {
-      colid = get_column_id(colids, colstart_local, Ncols, ncols_sampled,
-                            colcnt, local_flag);
-    }
-    for (unsigned int i = threadIdx.x; i < nbins * n_nodes; i += blockDim.x) {
-      shmempred[i] = (T)0;
-      shmemcount[i] = 0;
-    }
-    __syncthreads();
-
-    //Check if leaf
-    if (local_flag != LEAF) {
-      T local_data = data[tid + colid * nrows];
-      QuestionType question(question_ptr, colid, colcnt, n_nodes, local_flag,
-                            nbins);
-
-#pragma unroll(8)
-      for (unsigned int binid = 0; binid < nbins; binid++) {
-        if (local_data <= question(binid)) {
-          unsigned int nodeoff = local_flag * nbins;
-          raft::myAtomicAdd(&shmempred[nodeoff + binid],
-                            local_label * local_cnt);
-          raft::myAtomicAdd(&shmemcount[nodeoff + binid],
-                            local_cnt);
-        }
-      }
-    }
-
-    __syncthreads();
-    for (unsigned int i = threadIdx.x; i < nbins * n_nodes; i += blockDim.x) {
-      unsigned int offset = colcnt * nbins * n_nodes;
-      raft::myAtomicAdd(&predout[offset + i], shmempred[i]);
-      raft::myAtomicAdd(&countout[offset + i], shmemcount[i]);
-    }
-    __syncthreads();
-  }
-}
-
-//This kernel computes mse/mae for all cols, all bins and all nodes at a given level
-template
-__global__ void get_mse_kernel(
-  const T *__restrict__ data, const T *__restrict__ labels,
-  const unsigned int *__restrict__ flags,
-  const unsigned int *__restrict__ sample_cnt,
-  const unsigned int *__restrict__ colids,
-  const unsigned int *__restrict__ colstart, const int nrows, const int Ncols,
-  const int ncols_sampled, const int nbins, const int n_nodes,
-  const T *__restrict__ question_ptr, const T *__restrict__ parentpred,
-  const unsigned int *__restrict__ parentcount, const T *__restrict__ predout,
-  const unsigned int *__restrict__ countout, T *mseout) {
-  extern __shared__ char shmem_mse_kernel[];
-  T *shmem_predout = (T *)(shmem_mse_kernel);
-  T *shmem_mse = (T *)(shmem_mse_kernel + n_nodes * nbins * sizeof(T));
-  unsigned int *shmem_countout =
-    (unsigned int *)(shmem_mse_kernel + 3 * n_nodes * nbins * sizeof(T));
-
-  unsigned int local_flag = LEAF;
-  T local_label;
-  int local_cnt;
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  T parent_pred;
-  unsigned int parent_count;
-  unsigned int colid;
-  int colstart_local = -1;
-  if (tid < nrows) {
-    local_flag = flags[tid];
-  }
-
-  if (local_flag != LEAF) {
-    parent_count = parentcount[local_flag];
-    parent_pred = parentpred[local_flag];
-    local_label = labels[tid];
-    local_cnt = sample_cnt[tid];
-    if (colstart != nullptr) colstart_local = colstart[local_flag];
-  }
-
-  for (unsigned int colcnt = 0; colcnt < ncols_sampled; colcnt++) {
-    if (local_flag != LEAF) {
-      colid = get_column_id(colids, colstart_local, Ncols, ncols_sampled,
-                            colcnt, local_flag);
-    }
-    unsigned int coloff = colcnt * nbins * n_nodes;
-    for (unsigned int i = threadIdx.x; i < nbins * n_nodes; i += blockDim.x) {
-      shmem_predout[i] = predout[i + coloff];
-      shmem_countout[i] = countout[i + coloff];
-    }
-
-    for (unsigned int i = threadIdx.x; i < 2 * nbins * n_nodes;
-         i += blockDim.x) {
-      shmem_mse[i] = (T)0;
-    }
-    __syncthreads();
-
-    //Check if leaf
-    if (local_flag != LEAF) {
-      T local_data = data[tid + colid * nrows];
-      QuestionType question(question_ptr, colid, colcnt, n_nodes, local_flag,
-                            nbins);
-
-#pragma unroll(8)
-      for (unsigned int binid = 0; binid < nbins; binid++) {
-        unsigned int nodeoff = local_flag * nbins;
-        T local_pred = shmem_predout[nodeoff + binid];
-        unsigned int local_count = shmem_countout[nodeoff + binid];
-        if (local_data <= question(binid)) {
-          T leftmean = local_pred / local_count;
-          raft::myAtomicAdd(&shmem_mse[2 * (nodeoff + binid)],
-                            local_cnt * F::exec(local_label - leftmean));
-        } else {
-          T rightmean = parent_pred * parent_count - local_pred;
-          rightmean = rightmean / (parent_count - local_count);
-          raft::myAtomicAdd(&shmem_mse[2 * (nodeoff + binid) + 1],
-                            local_cnt * F::exec(local_label - rightmean));
-        }
-      }
-    }
-
-    __syncthreads();
-    for (unsigned int i = threadIdx.x; i < 2 * nbins * n_nodes;
-         i += blockDim.x) {
-      raft::myAtomicAdd(&mseout[2 * coloff + i], shmem_mse[i]);
-    }
-    __syncthreads();
-  }
-}
-
-//This kernel computes predictions and count for all cols, all bins and all nodes at a given level
-//This is used when nodes don't fit anymore in shared memory.
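Before the global-memory fallbacks below, a short host-side sketch of the bookkeeping both MSE kernels rely on: only the left-side label sum and row count are accumulated per bin, and the right side is recovered from the parent aggregates, since parent_pred * parent_count is the parent's label sum. The gain formula mirrors MSEGain::exec from this same backend; the sketch and its numbers are illustrative only.

#include <cstdio>

// Recover the right side of a candidate split from parent aggregates and
// compute the one-pass MSE-style gain: count-weighted squared child means
// minus the parent's squared mean (as in MSEGain::exec).
static float mse_split_gain(float parent_mean, unsigned int parent_count,
                            float left_sum, unsigned int left_count) {
  unsigned int right_count = parent_count - left_count;
  if (left_count == 0 || right_count == 0) return 0.0f;  // degenerate split
  float right_sum = parent_mean * parent_count - left_sum;
  float mean_left = left_sum / left_count;
  float mean_right = right_sum / right_count;
  float left_term = (static_cast<float>(left_count) / parent_count) *
                    mean_left * mean_left;
  float right_term = (static_cast<float>(right_count) / parent_count) *
                     mean_right * mean_right;
  return left_term + right_term - parent_mean * parent_mean;
}

int main() {
  // Parent node: 100 rows with mean label 2.0 (label sum 200); the candidate
  // bin routes 40 rows with label sum 120 to the left child.
  std::printf("gain = %f\n", mse_split_gain(2.0f, 100, 120.0f, 40));  // ~0.667
  return 0;
}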
-template
-__global__ void get_pred_kernel_global(
-  const T *__restrict__ data, const T *__restrict__ labels,
-  const unsigned int *__restrict__ flags,
-  const unsigned int *__restrict__ sample_cnt,
-  const unsigned int *__restrict__ colids,
-  const unsigned int *__restrict__ colstart, const int nrows, const int Ncols,
-  const int ncols_sampled, const int nbins, const int n_nodes,
-  const T *__restrict__ question_ptr, T *predout, unsigned int *countout) {
-  unsigned int local_flag = LEAF;
-  T local_label;
-  int local_cnt;
-  int threadid = threadIdx.x + blockIdx.x * blockDim.x;
-  for (int tid = threadid; tid < nrows; tid += blockDim.x * gridDim.x) {
-    local_flag = flags[tid];
-    //Check if leaf
-    if (local_flag != LEAF) {
-      local_label = labels[tid];
-      local_cnt = sample_cnt[tid];
-      int colstart_local = -1;
-      if (colstart != nullptr) colstart_local = colstart[local_flag];
-
-      for (unsigned int colcnt = 0; colcnt < ncols_sampled; colcnt++) {
-        unsigned int colid = get_column_id(colids, colstart_local, Ncols,
-                                           ncols_sampled, colcnt, local_flag);
-        unsigned int coloffset = colcnt * nbins * n_nodes;
-        T local_data = data[tid + colid * nrows];
-        QuestionType question(question_ptr, colid, colcnt, n_nodes, local_flag,
-                              nbins);
-
-#pragma unroll(8)
-        for (unsigned int binid = 0; binid < nbins; binid++) {
-          if (local_data <= question(binid)) {
-            unsigned int nodeoff = local_flag * nbins;
-            raft::myAtomicAdd(&predout[coloffset + nodeoff + binid],
-                              local_label * local_cnt);
-            raft::myAtomicAdd(
-              &countout[coloffset + nodeoff + binid], local_cnt);
-          }
-        }
-      }
-    }
-  }
-}
-
-//This kernel computes mse/mae for all cols, all bins and all nodes at a given level
-// This is used when nodes don't fit in shared memory
-template
-__global__ void get_mse_kernel_global(
-  const T *__restrict__ data, const T *__restrict__ labels,
-  const unsigned int *__restrict__ flags,
-  const unsigned int *__restrict__ sample_cnt,
-  const unsigned int *__restrict__ colids,
-  const unsigned int *__restrict__ colstart, const int nrows, const int Ncols,
-  const int ncols_sampled, const int nbins, const int n_nodes,
-  const T *__restrict__ question_ptr, const T *__restrict__ parentpred,
-  const unsigned int *__restrict__ parentcount, const T *__restrict__ predout,
-  const unsigned int *__restrict__ countout, T *mseout) {
-  unsigned int local_flag = LEAF;
-  T local_label;
-  int local_cnt;
-  int threadid = threadIdx.x + blockIdx.x * blockDim.x;
-  T parent_pred;
-  unsigned int parent_count;
-
-  for (int tid = threadid; tid < nrows; tid += gridDim.x * blockDim.x) {
-    local_flag = flags[tid];
-    if (local_flag != LEAF) {
-      local_label = labels[tid];
-      local_cnt = sample_cnt[tid];
-      parent_count = parentcount[local_flag];
-      parent_pred = parentpred[local_flag];
-      int colstart_local = -1;
-      if (colstart != nullptr) colstart_local = colstart[local_flag];
-
-      for (unsigned int colcnt = 0; colcnt < ncols_sampled; colcnt++) {
-        unsigned int colid = get_column_id(colids, colstart_local, Ncols,
-                                           ncols_sampled, colcnt, local_flag);
-        unsigned int coloff = colcnt * nbins * n_nodes;
-        T local_data = data[tid + colid * nrows];
-        QuestionType question(question_ptr, colid, colcnt, n_nodes, local_flag,
-                              nbins);
-
-#pragma unroll(8)
-        for (unsigned int binid = 0; binid < nbins; binid++) {
-          unsigned int nodeoff = local_flag * nbins;
-          T local_pred = predout[coloff + nodeoff + binid];
-          unsigned int local_count = countout[coloff + nodeoff + binid];
-          if (local_data <= question(binid)) {
-            T leftmean = local_pred / local_count;
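-            // left mean comes straight from the per-bin running sums; the
-            // else branch below recovers the right mean from the parent
-            // aggregates, as parent_pred * parent_count is the parent label sum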
-            raft::myAtomicAdd(&mseout[2 * (coloff + nodeoff + binid)],
-                              local_cnt * F::exec(local_label - leftmean));
-          } else {
-            T rightmean = parent_pred * parent_count - local_pred;
-            rightmean = rightmean / (parent_count - local_count);
-            raft::myAtomicAdd(&mseout[2 * (coloff + nodeoff + binid) + 1],
-                              local_cnt * F::exec(local_label - rightmean));
-          }
-        }
-      }
-    }
-  }
-}
-
-//This is the device version of best split finding, used when there are more than 512 nodes.
-template
-__global__ void get_best_split_regression_kernel(
-  const T *__restrict__ mseout, const T *__restrict__ predout,
-  const unsigned int *__restrict__ count, const T *__restrict__ parentmean,
-  const unsigned int *__restrict__ parentcount,
-  const T *__restrict__ parentmetric, const int nbins, const int ncols_sampled,
-  const int n_nodes, const int min_rpn, float *outgain, int *best_col_id,
-  int *best_bin_id, T *child_mean, unsigned int *child_count,
-  T *child_best_metric) {
-  typedef cub::BlockReduce BlockReduce;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
-
-  for (unsigned int nodeid = blockIdx.x; nodeid < n_nodes;
-       nodeid += gridDim.x) {
-    T parent_mean = parentmean[nodeid];
-    unsigned int parent_count = parentcount[nodeid];
-    T parent_metric = parentmetric[nodeid];
-    int nodeoffset = nodeid * nbins;
-    GainIdxPair tid_pair;
-    tid_pair.gain = 0.0;
-    tid_pair.idx = -1;
-    for (int id = threadIdx.x; id < nbins * ncols_sampled; id += blockDim.x) {
-      int coloffset = ((int)(id / nbins)) * nbins * n_nodes;
-      int binoffset = id % nbins;
-      int threadoffset = coloffset + binoffset + nodeoffset;
-      unsigned int tmp_lnrows = count[threadoffset];
-      unsigned int tmp_rnrows = parent_count - tmp_lnrows;
-      unsigned int totalrows = tmp_lnrows + tmp_rnrows;
-      if (tmp_lnrows == 0 || tmp_rnrows == 0 || totalrows < min_rpn) continue;
-      T tmp_meanleft = predout[threadoffset];
-      T tmp_meanright = parent_mean * parent_count - tmp_meanleft;
-      T tmp_mse_left = mseout[2 * threadoffset];
-      T tmp_mse_right = mseout[2 * threadoffset + 1];
-
-      float info_gain = (float)Gain::exec(
-        parent_metric, parent_count, tmp_lnrows, tmp_rnrows, parent_mean,
-        tmp_meanleft, tmp_meanright, tmp_mse_left, tmp_mse_right);
-
-      if (info_gain > tid_pair.gain) {
-        tid_pair.gain = info_gain;
-        tid_pair.idx = id;
-      }
-    }
-    __syncthreads();
-    GainIdxPair ans =
-      BlockReduce(temp_storage).Reduce(tid_pair, ReducePair());
-
-    if (threadIdx.x == 0 && ans.idx != -1) {
-      outgain[nodeid] = ans.gain;
-      best_col_id[nodeid] = (int)(ans.idx / nbins);
-      best_bin_id[nodeid] = ans.idx % nbins;
-      int coloffset = ((int)(ans.idx / nbins)) * nbins * n_nodes;
-      int binoffset = ans.idx % nbins;
-      int threadoffset = coloffset + binoffset + nodeoffset;
-      if (ans.idx != -1) {
-        unsigned int tmp_lnrows = count[threadoffset];
-        child_count[2 * nodeid] = tmp_lnrows;
-        unsigned int tmp_rnrows = parent_count - tmp_lnrows;
-        child_count[2 * nodeid + 1] = tmp_rnrows;
-        T tmp_meanleft = predout[threadoffset];
-        child_mean[2 * nodeid] = tmp_meanleft / tmp_lnrows;
-        child_mean[2 * nodeid + 1] =
-          (parent_mean * parent_count - tmp_meanleft) / tmp_rnrows;
-        child_best_metric[2 * nodeid] = mseout[2 * threadoffset] / tmp_lnrows;
-        child_best_metric[2 * nodeid + 1] =
-          mseout[2 * threadoffset + 1] / tmp_rnrows;
-      }
-    }
-  }
-}
-
-//This is the best bin finder at block level for each column using one pass MSE
-template
-DI GainIdxPair bin_info_gain_regression_mse(const T sum_parent,
-                                            const T *sum_left,
-                                            const unsigned int *count_right,
-                                            const int count, const int nbins) {
-  GainIdxPair tid_pair;
- 
tid_pair.gain = 0.0; - tid_pair.idx = -1; - for (int tid = threadIdx.x; tid < nbins; tid += blockDim.x) { - unsigned int right = count_right[tid]; - unsigned int left = count - right; - if ((right != 0) && (left != 0)) { - T mean_left = sum_left[tid]; - T mean_right = sum_parent - mean_left; - T mean_parent = sum_parent / count; - float info_gain = (float)MSEGain::exec(count, left, right, mean_parent, - mean_left, mean_right); - if (info_gain > tid_pair.gain) { - tid_pair.gain = info_gain; - tid_pair.idx = tid; - } - } - } - return tid_pair; -} - -//This is the best bin finder at block level for each column using two pass MAE -template -DI GainIdxPair bin_info_gain_regression_mae(const T mae_sum_parent, - const T *mae_sum_left, - const T *mae_sum_right, - const unsigned int *count_right, - const int count, const int nbins) { - GainIdxPair tid_pair; - tid_pair.gain = 0.0; - tid_pair.idx = -1; - for (int tid = threadIdx.x; tid < nbins; tid += blockDim.x) { - unsigned int right = count_right[tid]; - unsigned int left = count - right; - if ((right != 0) && (left != 0)) { - T mae_left = mae_sum_left[tid] / (T)left; - T mae_right = mae_sum_right[tid] / (T)right; - T mae_parent = mae_sum_parent / (T)count; - float info_gain = (float)MAEGain::exec(mae_parent, mae_left, mae_right, - left, right, count); - if (info_gain > tid_pair.gain) { - tid_pair.gain = info_gain; - tid_pair.idx = tid; - } - } - } - return tid_pair; -} - -//One pass best split using MSE -template -__global__ void best_split_gather_regression_mse_kernel( - const T *__restrict__ data, const T *__restrict__ labels, - const unsigned int *__restrict__ colids, - const unsigned int *__restrict__ colstart, const T *__restrict__ question_ptr, - const unsigned int *__restrict__ g_nodestart, - const unsigned int *__restrict__ samplelist, const int n_nodes, - const int nbins, const int nrows, const int Ncols, const int ncols_sampled, - const size_t treesz, const float min_impurity_split, - SparseTreeNode *d_sparsenodes, int *d_nodelist) { - //shmemhist_parent[n_unique_labels] - extern __shared__ char shmem_mse_gather[]; - T *shmean_left = (T *)(shmem_mse_gather); - unsigned int *shcount_right = (unsigned int *)(shmean_left + nbins); - __shared__ T mean_parent; - __shared__ GainIdxPair shmem_pair; - __shared__ int shmem_col; - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - int colstart_local = -1; - int colid; - T local_label; - unsigned int dataid; - unsigned int nodestart = g_nodestart[blockIdx.x]; - unsigned int count = g_nodestart[blockIdx.x + 1] - nodestart; - if (colstart != nullptr) colstart_local = colstart[blockIdx.x]; - - //Compute parent histograms - if (threadIdx.x == 0) { - mean_parent = 0.0; - shmem_pair.gain = 0.0f; - shmem_pair.idx = -1; - shmem_col = -1; - } - __syncthreads(); - for (int tid = threadIdx.x; tid < count; tid += blockDim.x) { - dataid = samplelist[nodestart + tid]; - local_label = labels[dataid]; - raft::myAtomicAdd(&mean_parent, local_label); - } - - //Loop over cols - for (unsigned int colcnt = 0; colcnt < ncols_sampled; colcnt++) { - colid = get_column_id(colids, colstart_local, Ncols, ncols_sampled, colcnt, - blockIdx.x); - for (int tid = threadIdx.x; tid < 2 * nbins; tid += blockDim.x) { - if (tid < nbins) - shmean_left[tid] = (T)0.0; - else - shcount_right[tid - nbins] = 0; - } - QuestionType question(question_ptr, colid, colcnt, n_nodes, blockIdx.x, - nbins); - __syncthreads(); - for (int tid = threadIdx.x; tid < count; tid += blockDim.x) { - dataid 
= get_samplelist(samplelist, dataid, nodestart, tid, count); - T local_data = data[dataid + colid * nrows]; - local_label = get_label(labels, local_label, dataid, count); -#pragma unroll(8) - for (unsigned int binid = 0; binid < nbins; binid++) { - if (local_data <= question(binid)) { - raft::myAtomicAdd(&shmean_left[binid], local_label); - } else { - raft::myAtomicAdd(&shcount_right[binid], 1); - } - } - } - __syncthreads(); - GainIdxPair bin_pair = bin_info_gain_regression_mse( - mean_parent, shmean_left, shcount_right, count, nbins); - GainIdxPair best_bin_pair = - BlockReduce(temp_storage).Reduce(bin_pair, ReducePair()); - __syncthreads(); - - if ((best_bin_pair.gain > shmem_pair.gain) && (threadIdx.x == 0)) { - shmem_pair = best_bin_pair; - shmem_col = colcnt; - } - } - __syncthreads(); - if (threadIdx.x == 0) { - SparseTreeNode localnode; - if ((shmem_col != -1) && (shmem_pair.gain > min_impurity_split)) { - colid = get_column_id(colids, colstart_local, Ncols, ncols_sampled, - shmem_col, blockIdx.x); - QuestionType question(question_ptr, colid, shmem_col, n_nodes, blockIdx.x, - nbins); - localnode.quesval = question(shmem_pair.idx); - localnode.left_child_id = treesz + 2 * blockIdx.x; - } else { - colid = -1; - localnode.prediction = mean_parent / count; - } - localnode.colid = colid; - localnode.best_metric_val = mean_parent / count; - d_sparsenodes[d_nodelist[blockIdx.x]] = localnode; - } -} - -//Same as above but fused with minmax mode, one pass min/max -// one pass MSE. total two pass. -template -__global__ void best_split_gather_regression_mse_minmax_kernel( - const T *__restrict__ data, const T *__restrict__ labels, - const unsigned int *__restrict__ colids, - const unsigned int *__restrict__ colstart, - const unsigned int *__restrict__ g_nodestart, - const unsigned int *__restrict__ samplelist, const int n_nodes, - const int nbins, const int nrows, const int Ncols, const int ncols_sampled, - const size_t treesz, const float min_impurity_split, const T init_min_val, - SparseTreeNode *d_sparsenodes, int *d_nodelist) { - //shmemhist_parent[n_unique_labels] - extern __shared__ char shmem_mse_minmax_gather[]; - T *shmean_left = (T *)(shmem_mse_minmax_gather); - unsigned int *shcount_right = (unsigned int *)(shmean_left + nbins); - __shared__ T mean_parent; - __shared__ GainIdxPair shmem_pair; - __shared__ int shmem_col; - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ T shmem_min, shmem_max, best_min, best_delta; - - int colstart_local = -1; - int colid; - T local_label; - unsigned int dataid; - T local_data; - unsigned int nodestart = g_nodestart[blockIdx.x]; - unsigned int count = g_nodestart[blockIdx.x + 1] - nodestart; - if (colstart != nullptr) colstart_local = colstart[blockIdx.x]; - - //Compute parent histograms - if (threadIdx.x == 0) { - mean_parent = 0.0; - shmem_pair.gain = 0.0f; - shmem_pair.idx = -1; - shmem_col = -1; - } - __syncthreads(); - for (int tid = threadIdx.x; tid < count; tid += blockDim.x) { - dataid = samplelist[nodestart + tid]; - local_label = labels[dataid]; - raft::myAtomicAdd(&mean_parent, local_label); - } - - //Loop over cols - for (unsigned int colcnt = 0; colcnt < ncols_sampled; colcnt++) { - if (threadIdx.x == 0) { - *(E *)&shmem_min = MLCommon::Stats::encode(init_min_val); - *(E *)&shmem_max = MLCommon::Stats::encode(-init_min_val); - } - colid = get_column_id(colids, colstart_local, Ncols, ncols_sampled, colcnt, - blockIdx.x); - for (int tid = threadIdx.x; tid < 2 * nbins; tid += 
blockDim.x) {
-      if (tid < nbins)
-        shmean_left[tid] = (T)0.0;
-      else
-        shcount_right[tid - nbins] = 0;
-    }
-    __syncthreads();
-
-    //Compute min/max on an independent data pass
-    for (int tid = threadIdx.x; tid < count; tid += blockDim.x) {
-      unsigned int dataid = samplelist[nodestart + tid];
-      local_data = data[dataid + colid * nrows];
-      MLCommon::Stats::atomicMinBits(&shmem_min, local_data);
-      MLCommon::Stats::atomicMaxBits(&shmem_max, local_data);
-    }
-    __syncthreads();
-
-    T threadmin = MLCommon::Stats::decode(*(E *)&shmem_min);
-    T delta =
-      (MLCommon::Stats::decode(*(E *)&shmem_max) - threadmin) / (nbins + 1);
-
-    for (int tid = threadIdx.x; tid < count; tid += blockDim.x) {
-      dataid = get_samplelist(samplelist, dataid, nodestart, tid, count);
-      local_data = get_data(data, local_data, dataid + colid * nrows, count);
-      local_label = get_label(labels, local_label, dataid, count);
-#pragma unroll(8)
-      for (unsigned int binid = 0; binid < nbins; binid++) {
-        if (local_data <= threadmin + delta * (binid + 1)) {
-          raft::myAtomicAdd(&shmean_left[binid], local_label);
-        } else {
-          raft::myAtomicAdd(&shcount_right[binid], 1);
-        }
-      }
-    }
-    __syncthreads();
-    GainIdxPair bin_pair = bin_info_gain_regression_mse(
-      mean_parent, shmean_left, shcount_right, count, nbins);
-    GainIdxPair best_bin_pair =
-      BlockReduce(temp_storage).Reduce(bin_pair, ReducePair());
-    __syncthreads();
-
-    if ((best_bin_pair.gain > shmem_pair.gain) && (threadIdx.x == 0)) {
-      shmem_pair = best_bin_pair;
-      shmem_col = colcnt;
-      best_min = threadmin;
-      best_delta = delta;
-    }
-  }
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    SparseTreeNode localnode;
-    if ((shmem_col != -1) && (shmem_pair.gain > min_impurity_split)) {
-      colid = get_column_id(colids, colstart_local, Ncols, ncols_sampled,
-                            shmem_col, blockIdx.x);
-      localnode.quesval = best_min + (shmem_pair.idx + 1) * best_delta;
-      localnode.left_child_id = treesz + 2 * blockIdx.x;
-    } else {
-      colid = -1;
-      localnode.prediction = mean_parent / count;
-    }
-    localnode.colid = colid;
-    localnode.best_metric_val = mean_parent / count;
-    d_sparsenodes[d_nodelist[blockIdx.x]] = localnode;
-  }
-}
-
-//A lightweight implementation of the best split kernel for the last level,
-// when all nodes are to be leafed out. Works for all algos and all split criteria
-template
-__global__ void make_leaf_gather_regression_kernel(
-  const T *__restrict__ labels, const unsigned int *__restrict__ g_nodestart,
-  const unsigned int *__restrict__ samplelist,
-  SparseTreeNode *d_sparsenodes, int *d_nodelist) {
-  __shared__ T mean_parent;
-  unsigned int nodestart = g_nodestart[blockIdx.x];
-  unsigned int count = g_nodestart[blockIdx.x + 1] - nodestart;
-
-  //Compute parent histograms
-  mean_parent = 0.0f;
-  __syncthreads();
-
-  for (int tid = threadIdx.x; tid < count; tid += blockDim.x) {
-    unsigned int dataid = samplelist[nodestart + tid];
-    T local_label = labels[dataid];
-    raft::myAtomicAdd(&mean_parent, local_label);
-  }
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    SparseTreeNode localnode;
-    localnode.prediction = mean_parent / count;
-    localnode.colid = -1;
-    localnode.best_metric_val = mean_parent / count;
-    d_sparsenodes[d_nodelist[blockIdx.x]] = localnode;
-  }
-}
-
-//Gather kernel for MAE.
We need this different as MAE needs to be multipass -// One pass for mean and one pass for MAE -template -__global__ void best_split_gather_regression_mae_kernel( - const T *__restrict__ data, const T *__restrict__ labels, - const unsigned int *__restrict__ colids, - const unsigned int *__restrict__ colstart, const T *__restrict__ question_ptr, - const unsigned int *__restrict__ g_nodestart, - const unsigned int *__restrict__ samplelist, const int n_nodes, - const int nbins, const int nrows, const int Ncols, const int ncols_sampled, - const size_t treesz, const float min_impurity_split, - SparseTreeNode *d_sparsenodes, int *d_nodelist) { - extern __shared__ char shmem_mae_gather[]; - T *shmean_left = (T *)shmem_mae_gather; - T *shmae_left = (T *)(shmean_left + nbins); - T *shmae_right = (T *)(shmae_left + nbins); - unsigned int *shcount_right = (unsigned int *)(shmae_right + nbins); - __shared__ T mean_parent, mae_parent; - __shared__ GainIdxPair shmem_pair; - __shared__ int shmem_col; - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - int colstart_local = -1; - int colid; - T local_label; - T local_data; - unsigned int dataid; - unsigned int nodestart = g_nodestart[blockIdx.x]; - unsigned int count = g_nodestart[blockIdx.x + 1] - nodestart; - if (colstart != nullptr) colstart_local = colstart[blockIdx.x]; - - //Compute parent histograms - if (threadIdx.x == 0) { - mean_parent = 0.0; - shmem_pair.gain = 0.0f; - shmem_pair.idx = -1; - shmem_col = -1; - } - __syncthreads(); - for (int tid = threadIdx.x; tid < count; tid += blockDim.x) { - dataid = samplelist[nodestart + tid]; - local_label = labels[dataid]; - raft::myAtomicAdd(&mean_parent, local_label); - } - __syncthreads(); - for (int tid = threadIdx.x; tid < count; tid += blockDim.x) { - dataid = get_samplelist(samplelist, dataid, nodestart, tid, count); - local_label = get_label(labels, local_label, dataid, count); - T value = (mean_parent / count) - local_label; - raft::myAtomicAdd(&mae_parent, raft::myAbs(value)); - } - //Loop over cols - for (unsigned int colcnt = 0; colcnt < ncols_sampled; colcnt++) { - colid = get_column_id(colids, colstart_local, Ncols, ncols_sampled, colcnt, - blockIdx.x); - for (int tid = threadIdx.x; tid < 2 * nbins; tid += blockDim.x) { - if (tid < nbins) { - shmean_left[tid] = (T)0; - shmae_left[tid] = (T)0; - } else { - shcount_right[tid - nbins] = 0; - shmae_right[tid - nbins] = (T)0; - } - } - QuestionType question(question_ptr, colid, colcnt, n_nodes, blockIdx.x, - nbins); - __syncthreads(); - for (int tid = threadIdx.x; tid < count; tid += blockDim.x) { - dataid = get_samplelist(samplelist, dataid, nodestart, tid, count); - local_data = data[dataid + colid * nrows]; - local_label = get_label(labels, local_label, dataid, count); -#pragma unroll(8) - for (unsigned int binid = 0; binid < nbins; binid++) { - if (local_data <= question(binid)) { - raft::myAtomicAdd(&shmean_left[binid], local_label); - } else { - raft::myAtomicAdd(&shcount_right[binid], 1); - } - } - } - __syncthreads(); - //second data pass is needed for MAE - for (int tid = threadIdx.x; tid < count; tid += blockDim.x) { - dataid = get_samplelist(samplelist, dataid, nodestart, tid, count); - local_data = get_data(data, local_data, dataid + colid * nrows, count); - local_label = get_label(labels, local_label, dataid, count); -#pragma unroll(8) - for (unsigned int binid = 0; binid < nbins; binid++) { - if (local_data <= question(binid)) { - T value = - (shmean_left[binid] / (count - 
shcount_right[binid])) - local_label; - raft::myAtomicAdd(&shmae_left[binid], raft::myAbs(value)); - } else { - T value = - ((mean_parent - shmean_left[binid]) / shcount_right[binid]) - - local_label; - raft::myAtomicAdd(&shmae_right[binid], raft::myAbs(value)); - } - } - } - __syncthreads(); - GainIdxPair bin_pair = bin_info_gain_regression_mae( - mae_parent, shmae_left, shmae_right, shcount_right, count, nbins); - GainIdxPair best_bin_pair = - BlockReduce(temp_storage).Reduce(bin_pair, ReducePair()); - __syncthreads(); - - if ((best_bin_pair.gain > shmem_pair.gain) && (threadIdx.x == 0)) { - shmem_pair = best_bin_pair; - shmem_col = colcnt; - } - } - __syncthreads(); - if (threadIdx.x == 0) { - SparseTreeNode localnode; - if ((shmem_col != -1) && (shmem_pair.gain > min_impurity_split)) { - colid = get_column_id(colids, colstart_local, Ncols, ncols_sampled, - shmem_col, blockIdx.x); - QuestionType question(question_ptr, colid, shmem_col, n_nodes, blockIdx.x, - nbins); - localnode.quesval = question(shmem_pair.idx); - localnode.left_child_id = treesz + 2 * blockIdx.x; - } else { - colid = -1; - localnode.prediction = mean_parent / count; - } - localnode.colid = colid; - localnode.best_metric_val = mae_parent / count; - d_sparsenodes[d_nodelist[blockIdx.x]] = localnode; - } -} - -//Same as above but fused with minmax mode, one pass min/max -// one pass Mean. one pass MAE. total three passes. -template -__global__ void best_split_gather_regression_mae_minmax_kernel( - const T *__restrict__ data, const T *__restrict__ labels, - const unsigned int *__restrict__ colids, - const unsigned int *__restrict__ colstart, - const unsigned int *__restrict__ g_nodestart, - const unsigned int *__restrict__ samplelist, const int n_nodes, - const int nbins, const int nrows, const int Ncols, const int ncols_sampled, - const size_t treesz, const float min_impurity_split, const T init_min_val, - SparseTreeNode *d_sparsenodes, int *d_nodelist) { - //shmemhist_parent[n_unique_labels] - extern __shared__ char shmem_mae_minmax_gather[]; - T *shmean_left = (T *)shmem_mae_minmax_gather; - T *shmae_left = (T *)(shmean_left + nbins); - T *shmae_right = (T *)(shmae_left + nbins); - unsigned int *shcount_right = (unsigned int *)(shmae_right + nbins); - __shared__ T mean_parent, mae_parent; - __shared__ GainIdxPair shmem_pair; - __shared__ int shmem_col; - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ T shmem_min, shmem_max, best_min, best_delta; - - int colstart_local = -1; - int colid; - T local_label; - unsigned int dataid; - T local_data; - unsigned int nodestart = g_nodestart[blockIdx.x]; - unsigned int count = g_nodestart[blockIdx.x + 1] - nodestart; - if (colstart != nullptr) colstart_local = colstart[blockIdx.x]; - - //Compute parent histograms - if (threadIdx.x == 0) { - mean_parent = 0.0; - mae_parent = 0.0; - shmem_pair.gain = 0.0f; - shmem_pair.idx = -1; - shmem_col = -1; - } - __syncthreads(); - for (int tid = threadIdx.x; tid < count; tid += blockDim.x) { - dataid = samplelist[nodestart + tid]; - local_label = labels[dataid]; - raft::myAtomicAdd(&mean_parent, local_label); - } - __syncthreads(); - for (int tid = threadIdx.x; tid < count; tid += blockDim.x) { - dataid = get_samplelist(samplelist, dataid, nodestart, tid, count); - local_label = get_label(labels, local_label, dataid, count); - T value = (mean_parent / count) - local_label; - raft::myAtomicAdd(&mae_parent, raft::myAbs(value)); - } - - //Loop over cols - for (unsigned int colcnt = 
0; colcnt < ncols_sampled; colcnt++) {
-    if (threadIdx.x == 0) {
-      *(E *)&shmem_min = MLCommon::Stats::encode(init_min_val);
-      *(E *)&shmem_max = MLCommon::Stats::encode(-init_min_val);
-    }
-    colid = get_column_id(colids, colstart_local, Ncols, ncols_sampled, colcnt,
-                          blockIdx.x);
-    for (int tid = threadIdx.x; tid < 2 * nbins; tid += blockDim.x) {
-      if (tid < nbins) {
-        shmean_left[tid] = (T)0;
-        shmae_left[tid] = (T)0;
-      } else {
-        shcount_right[tid - nbins] = 0;
-        shmae_right[tid - nbins] = (T)0;
-      }
-    }
-    __syncthreads();
-
-    //Compute min/max on an independent data pass
-    for (int tid = threadIdx.x; tid < count; tid += blockDim.x) {
-      unsigned int dataid = samplelist[nodestart + tid];
-      local_data = data[dataid + colid * nrows];
-      MLCommon::Stats::atomicMinBits(&shmem_min, local_data);
-      MLCommon::Stats::atomicMaxBits(&shmem_max, local_data);
-    }
-    __syncthreads();
-
-    T threadmin = MLCommon::Stats::decode(*(E *)&shmem_min);
-    T delta =
-      (MLCommon::Stats::decode(*(E *)&shmem_max) - threadmin) / (nbins + 1);
-
-    //Second pass for Mean
-    for (int tid = threadIdx.x; tid < count; tid += blockDim.x) {
-      dataid = get_samplelist(samplelist, dataid, nodestart, tid, count);
-      local_data = get_data(data, local_data, dataid + colid * nrows, count);
-      local_label = get_label(labels, local_label, dataid, count);
-#pragma unroll(8)
-      for (unsigned int binid = 0; binid < nbins; binid++) {
-        if (local_data <= threadmin + delta * (binid + 1)) {
-          raft::myAtomicAdd(&shmean_left[binid], local_label);
-        } else {
-          raft::myAtomicAdd(&shcount_right[binid], 1);
-        }
-      }
-    }
-    __syncthreads();
-    //Third pass needed for MAE
-    for (int tid = threadIdx.x; tid < count; tid += blockDim.x) {
-      dataid = get_samplelist(samplelist, dataid, nodestart, tid, count);
-      local_data = get_data(data, local_data, dataid + colid * nrows, count);
-      local_label = get_label(labels, local_label, dataid, count);
-#pragma unroll(8)
-      for (unsigned int binid = 0; binid < nbins; binid++) {
-        if (local_data <= threadmin + delta * (binid + 1)) {
-          T value =
-            (shmean_left[binid] / (count - shcount_right[binid])) - local_label;
-          raft::myAtomicAdd(&shmae_left[binid], raft::myAbs(value));
-        } else {
-          T value =
-            ((mean_parent - shmean_left[binid]) / shcount_right[binid]) -
-            local_label;
-          raft::myAtomicAdd(&shmae_right[binid], raft::myAbs(value));
-        }
-      }
-    }
-    __syncthreads();
-
-    GainIdxPair bin_pair = bin_info_gain_regression_mae(
-      mae_parent, shmae_left, shmae_right, shcount_right, count, nbins);
-    GainIdxPair best_bin_pair =
-      BlockReduce(temp_storage).Reduce(bin_pair, ReducePair());
-    __syncthreads();
-
-    if ((best_bin_pair.gain > shmem_pair.gain) && (threadIdx.x == 0)) {
-      shmem_pair = best_bin_pair;
-      shmem_col = colcnt;
-      best_min = threadmin;
-      best_delta = delta;
-    }
-  }
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    SparseTreeNode localnode;
-    if ((shmem_col != -1) && (shmem_pair.gain > min_impurity_split)) {
-      colid = get_column_id(colids, colstart_local, Ncols, ncols_sampled,
-                            shmem_col, blockIdx.x);
-      localnode.quesval = best_min + (shmem_pair.idx + 1) * best_delta;
-      localnode.left_child_id = treesz + 2 * blockIdx.x;
-    } else {
-      colid = -1;
-      localnode.prediction = mean_parent / count;
-    }
-    localnode.colid = colid;
-    localnode.best_metric_val = mae_parent / count;
-    d_sparsenodes[d_nodelist[blockIdx.x]] = localnode;
-  }
-}
-
-} // namespace DecisionTree
-} // namespace ML
diff --git a/cpp/src/decisiontree/levelalgo/metric.cuh b/cpp/src/decisiontree/levelalgo/metric.cuh deleted file mode 100644 index
0f533decf7..0000000000 --- a/cpp/src/decisiontree/levelalgo/metric.cuh +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#include "metric_def.cuh" - -namespace ML { -namespace DecisionTree { - -template -DI T SquareFunctor::exec(T x) { - return raft::myPow(x, (T)2); -} - -template -DI T AbsFunctor::exec(T x) { - return raft::myAbs(x); -} - -float GiniFunctor::max_val(int nclass) { return 1.0; } - -float EntropyFunctor::max_val(int nclass) { - float prob = 1.0 / nclass; - return (-1.0 * nclass * prob * logf(prob)); -} -float GiniFunctor::exec(std::vector &hist, int nrows) { - float gval = 1.0; - for (int i = 0; i < hist.size(); i++) { - float prob = ((float)hist[i]) / nrows; - gval -= prob * prob; - } - return gval; -} - -float EntropyFunctor::exec(std::vector &hist, int nrows) { - float eval = 0.0; - for (int i = 0; i < hist.size(); i++) { - if (hist[i] != 0) { - float prob = ((float)hist[i]) / nrows; - eval += prob * logf(prob); - } - } - return (-1 * eval); -} - -} // namespace DecisionTree -} // namespace ML diff --git a/cpp/src/decisiontree/levelalgo/metric_def.cuh b/cpp/src/decisiontree/levelalgo/metric_def.cuh deleted file mode 100644 index 90a264c6a5..0000000000 --- a/cpp/src/decisiontree/levelalgo/metric_def.cuh +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#include -#include -#include -#include "../memory.h" - -namespace ML { -namespace DecisionTree { - -struct SquareFunctor { - template - static DI T exec(T x); -}; - -struct AbsFunctor { - template - static DI T exec(T x); -}; - -struct GiniFunctor { - static float exec(std::vector& hist, int nrows); - static float max_val(int nclass); -}; - -struct EntropyFunctor { - static float exec(std::vector& hist, int nrows); - static float max_val(int nclass); -}; - -} // namespace DecisionTree -} // namespace ML diff --git a/cpp/src/decisiontree/memory.cuh b/cpp/src/decisiontree/memory.cuh deleted file mode 100644 index 599b8c0860..0000000000 --- a/cpp/src/decisiontree/memory.cuh +++ /dev/null @@ -1,348 +0,0 @@ -/* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include "memory.h" - -template -TemporaryMemory::TemporaryMemory( - const std::shared_ptr device_allocator_in, - const std::shared_ptr host_allocator_in, - const cudaStream_t stream_in, int N, int Ncols, int n_unique, - const ML::DecisionTree::DecisionTreeParams& tree_params) { - stream = stream_in; - max_shared_mem = raft::getSharedMemPerBlock(); - num_sms = raft::getMultiProcessorCount(); - device_allocator = device_allocator_in; - host_allocator = host_allocator_in; - LevelMemAllocator(N, Ncols, n_unique, tree_params); -} - -template -TemporaryMemory::TemporaryMemory( - const raft::handle_t& handle, cudaStream_t stream_in, int N, int Ncols, - int n_unique, const ML::DecisionTree::DecisionTreeParams& tree_params) { - stream = stream_in; - max_shared_mem = raft::getSharedMemPerBlock(); - num_sms = raft::getMultiProcessorCount(); - device_allocator = handle.get_device_allocator(); - host_allocator = handle.get_host_allocator(); - LevelMemAllocator(N, Ncols, n_unique, tree_params); -} - -template -TemporaryMemory::~TemporaryMemory() { - LevelMemCleaner(); -} - -template -void TemporaryMemory::print_info(int depth, int nrows, int ncols, - float colper) { - size_t maxnodes = max_nodes_per_level; - size_t ncols_sampled = (size_t)(ncols * colper); - - ML::PatternSetter _("%v"); - CUML_LOG_DEBUG("maxnodes --> %lu gather maxnodes--> %lu", maxnodes, - gather_max_nodes); - CUML_LOG_DEBUG("Parent size --> %lu", parentsz); - CUML_LOG_DEBUG("Child size --> %lu", childsz); - CUML_LOG_DEBUG("Nrows size --> %d", (nrows + 1)); - CUML_LOG_DEBUG("Sparse tree holder size --> %lu", 2 * gather_max_nodes); - CUML_LOG_DEBUG(" Total temporary memory usage--> %lf MB", - ((double)totalmem / (1024 * 1024))); -} - -template -void TemporaryMemory::LevelMemAllocator( - int nrows, int ncols, int n_unique, - const ML::DecisionTree::DecisionTreeParams& tree_params) { - int nbins = tree_params.n_bins; - int depth = (tree_params.max_depth < 0) ? -1 : (tree_params.max_depth + 1); - if (depth > swap_depth || (depth == -1)) { - max_nodes_per_level = pow(2, swap_depth); - } else { - max_nodes_per_level = pow(2, depth); - } - size_t maxnodes = max_nodes_per_level; - size_t ncols_sampled = (size_t)(ncols * tree_params.max_features); - ncols_sampled = ncols_sampled > 0 ? 
ncols_sampled : 1; - if (depth < 64) { - gather_max_nodes = std::min((size_t)(nrows + 1), - (size_t)(pow((size_t)2, (size_t)depth) + 1)); - } else { - gather_max_nodes = nrows + 1; - } - parentsz = std::max(maxnodes, gather_max_nodes); - childsz = std::max(2 * maxnodes, 2 * gather_max_nodes); - - d_flags = - new MLCommon::device_buffer(device_allocator, stream, nrows); - h_new_node_flags = - new MLCommon::host_buffer(host_allocator, stream, maxnodes); - d_new_node_flags = new MLCommon::device_buffer( - device_allocator, stream, maxnodes); - totalmem += nrows * sizeof(int) + maxnodes * sizeof(int); - //This buffers will be renamed and reused in gather algorithms - h_split_colidx = - new MLCommon::host_buffer(host_allocator, stream, parentsz); - h_split_binidx = - new MLCommon::host_buffer(host_allocator, stream, parentsz); - d_split_colidx = - new MLCommon::device_buffer(device_allocator, stream, parentsz); - d_split_binidx = - new MLCommon::device_buffer(device_allocator, stream, parentsz); - size_t metric_size = std::max(parentsz, (size_t)(nrows + 1)); - h_parent_metric = - new MLCommon::host_buffer(host_allocator, stream, metric_size); - d_parent_metric = - new MLCommon::device_buffer(device_allocator, stream, metric_size); - h_child_best_metric = - new MLCommon::host_buffer(host_allocator, stream, childsz); - h_outgain = - new MLCommon::host_buffer(host_allocator, stream, parentsz); - d_child_best_metric = - new MLCommon::device_buffer(device_allocator, stream, childsz); - d_outgain = - new MLCommon::device_buffer(device_allocator, stream, parentsz); - //end of reusable buffers - totalmem = - 3 * parentsz * sizeof(int) + childsz * sizeof(T) + (nrows + 1) * sizeof(T); - - if (tree_params.split_algo == 0) { - d_globalminmax = new MLCommon::device_buffer( - device_allocator, stream, 2 * maxnodes * ncols_sampled); - h_globalminmax = new MLCommon::host_buffer(host_allocator, stream, - 2 * maxnodes * ncols_sampled); - totalmem += maxnodes * ncols * sizeof(T); - } else { - h_quantile = - new MLCommon::host_buffer(host_allocator, stream, nbins * ncols); - d_quantile = - new MLCommon::device_buffer(device_allocator, stream, nbins * ncols); - totalmem += nbins * ncols * sizeof(T); - } - d_sample_cnt = - new MLCommon::device_buffer(device_allocator, stream, nrows); - //This buffers are also reused by gather algorithm - d_colids = - new MLCommon::device_buffer(device_allocator, stream, ncols); - d_colstart = new MLCommon::device_buffer(device_allocator, - stream, parentsz); - h_colids = - new MLCommon::host_buffer(host_allocator, stream, ncols); - h_colstart = - new MLCommon::host_buffer(host_allocator, stream, parentsz); - totalmem += ncols * sizeof(int) + parentsz * sizeof(int); - //CUB memory for gather algorithms - size_t temp_storage_bytes = 0; - void* cub_buffer = NULL; - cub::DeviceScan::ExclusiveSum(cub_buffer, temp_storage_bytes, - d_split_colidx->data(), d_split_binidx->data(), - gather_max_nodes); - temp_cub_buffer = new MLCommon::device_buffer(device_allocator, stream, - temp_storage_bytes); - h_counter = new MLCommon::host_buffer(host_allocator, stream, 1); - d_counter = new MLCommon::device_buffer(device_allocator, stream, 1); - temp_cub_bytes = temp_storage_bytes; - totalmem += temp_cub_bytes + 1; - - //Allocate node vectors - d_sparsenodes = new MLCommon::device_buffer>( - device_allocator, stream, 2 * gather_max_nodes); - h_sparsenodes = new MLCommon::host_buffer>( - host_allocator, stream, 2 * gather_max_nodes); - totalmem += 2 * gather_max_nodes * sizeof(SparseTreeNode); - - 
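The sizing rule above caps how many nodes the level backend gathers per level: a tree of depth d has at most 2^d nodes on its deepest level, and no level can hold more nodes than there are rows, since every node owns at least one sample. A minimal host-side restatement of that bound (the helper name is illustrative; the original computes it via pow on size_t rather than a shift):

    #include <algorithm>
    #include <cstddef>

    // Upper bound on nodes per level for the removed level-by-level backend:
    // at most one node per row, and at most 2^depth + 1 nodes on one level.
    inline std::size_t gather_max_nodes(std::size_t nrows, int depth) {
      if (depth < 0 || depth >= 64) return nrows + 1;  // depth unbounded
      return std::min(nrows + 1, (std::size_t{1} << depth) + 1);
    }
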
//Regression - if (typeid(L) == typeid(T)) { - d_mseout = new MLCommon::device_buffer( - device_allocator, stream, 2 * nbins * ncols_sampled * maxnodes); - d_predout = new MLCommon::device_buffer( - device_allocator, stream, nbins * ncols_sampled * maxnodes); - d_count = new MLCommon::device_buffer( - device_allocator, stream, nbins * ncols_sampled * maxnodes); - d_parent_pred = - new MLCommon::device_buffer(device_allocator, stream, maxnodes); - d_parent_count = new MLCommon::device_buffer( - device_allocator, stream, maxnodes); - d_child_pred = - new MLCommon::device_buffer(device_allocator, stream, 2 * maxnodes); - d_child_count = new MLCommon::device_buffer( - device_allocator, stream, 2 * maxnodes); - h_mseout = new MLCommon::host_buffer( - host_allocator, stream, 2 * nbins * ncols_sampled * maxnodes); - h_predout = new MLCommon::host_buffer(host_allocator, stream, - nbins * ncols_sampled * maxnodes); - h_count = new MLCommon::host_buffer( - host_allocator, stream, nbins * ncols_sampled * maxnodes); - h_child_pred = - new MLCommon::host_buffer(host_allocator, stream, 2 * maxnodes); - h_child_count = new MLCommon::host_buffer( - host_allocator, stream, 2 * maxnodes); - - totalmem += 3 * nbins * ncols_sampled * maxnodes * sizeof(T); - totalmem += nbins * ncols_sampled * maxnodes * sizeof(unsigned int); - totalmem += 3 * maxnodes * sizeof(T); - totalmem += 3 * maxnodes * sizeof(unsigned int); - } - - //Classification - if (typeid(L) == typeid(int)) { - size_t histcount = ncols_sampled * nbins * n_unique * maxnodes; - d_histogram = new MLCommon::device_buffer(device_allocator, - stream, histcount); - h_histogram = new MLCommon::host_buffer(host_allocator, - stream, histcount); - h_parent_hist = new MLCommon::host_buffer( - host_allocator, stream, maxnodes * n_unique); - h_child_hist = new MLCommon::host_buffer( - host_allocator, stream, 2 * maxnodes * n_unique); - d_parent_hist = new MLCommon::device_buffer( - device_allocator, stream, maxnodes * n_unique); - d_child_hist = new MLCommon::device_buffer( - device_allocator, stream, 2 * maxnodes * n_unique); - totalmem += histcount * sizeof(unsigned int); - totalmem += n_unique * maxnodes * 3 * sizeof(unsigned int); - } - //Calculate Max nodes in shared memory. - if (typeid(L) == typeid(int)) { - max_nodes_class = max_shared_mem / (nbins * n_unique * sizeof(int)); - max_nodes_class /= 2; // For occupancy purposes. - } - if (typeid(L) == typeid(T)) { - size_t pernode_pred = nbins * (sizeof(T) + sizeof(unsigned int)); - max_nodes_pred = max_shared_mem / pernode_pred; - max_nodes_mse = max_shared_mem / (pernode_pred + 2 * nbins * sizeof(T)); - max_nodes_pred /= 2; // For occupancy purposes. - max_nodes_mse /= 2; // For occupancy purposes. 
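The occupancy arithmetic above divides the per-block shared memory by each node's scratch footprint and halves the result so two blocks can stay resident per SM. A condensed sketch of both budgets, assuming max_shared_mem comes from raft::getSharedMemPerBlock() as in the removed constructors:

    #include <cstddef>

    // Classification: one int histogram counter per (bin, class) pair.
    inline int max_nodes_classification(std::size_t max_shared_mem,
                                        int nbins, int n_unique) {
      const std::size_t per_node =
        static_cast<std::size_t>(nbins) * n_unique * sizeof(int);
      return static_cast<int>(max_shared_mem / per_node) / 2;  // halve for occupancy
    }

    // Regression: per bin, a running prediction (T) plus a sample counter.
    template <typename T>
    int max_nodes_prediction(std::size_t max_shared_mem, int nbins) {
      const std::size_t per_node = nbins * (sizeof(T) + sizeof(unsigned int));
      return static_cast<int>(max_shared_mem / per_node) / 2;
    }
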
- } - if (tree_params.split_algo == ML::SPLIT_ALGO::HIST) { - size_t shmem_per_node = 2 * sizeof(T); - max_nodes_minmax = max_shared_mem / shmem_per_node; - max_nodes_minmax /= 2; - } -} - -template -void TemporaryMemory::LevelMemCleaner() { - h_new_node_flags->release(stream); - d_new_node_flags->release(stream); - h_split_colidx->release(stream); - d_split_colidx->release(stream); - h_split_binidx->release(stream); - d_split_binidx->release(stream); - h_parent_metric->release(stream); - h_child_best_metric->release(stream); - h_outgain->release(stream); - d_parent_metric->release(stream); - d_child_best_metric->release(stream); - d_outgain->release(stream); - d_flags->release(stream); - if (h_quantile != nullptr) h_quantile->release(stream); - if (d_quantile != nullptr) d_quantile->release(stream); - if (d_globalminmax != nullptr) d_globalminmax->release(stream); - if (h_globalminmax != nullptr) h_globalminmax->release(stream); - d_sample_cnt->release(stream); - d_colids->release(stream); - if (d_colstart != nullptr) d_colstart->release(stream); - h_colids->release(stream); - if (h_colstart != nullptr) h_colstart->release(stream); - delete h_new_node_flags; - delete d_new_node_flags; - delete h_split_colidx; - delete d_split_colidx; - delete h_split_binidx; - delete d_split_binidx; - delete h_parent_metric; - delete h_child_best_metric; - delete h_outgain; - delete d_parent_metric; - delete d_child_best_metric; - delete d_outgain; - delete d_flags; - if (h_quantile != nullptr) delete h_quantile; - if (d_quantile != nullptr) delete d_quantile; - if (d_globalminmax != nullptr) delete d_globalminmax; - if (h_globalminmax != nullptr) delete h_globalminmax; - delete d_sample_cnt; - delete d_colids; - delete h_colids; - if (d_colstart != nullptr) delete d_colstart; - if (h_colstart != nullptr) delete h_colstart; - temp_cub_buffer->release(stream); - delete temp_cub_buffer; - h_counter->release(stream); - d_counter->release(stream); - delete h_counter; - delete d_counter; - d_sparsenodes->release(stream); - h_sparsenodes->release(stream); - delete d_sparsenodes; - delete h_sparsenodes; - //Classification - if (typeid(L) == typeid(int)) { - h_histogram->release(stream); - d_histogram->release(stream); - h_parent_hist->release(stream); - h_child_hist->release(stream); - d_parent_hist->release(stream); - d_child_hist->release(stream); - delete d_histogram; - delete h_histogram; - delete h_parent_hist; - delete h_child_hist; - delete d_parent_hist; - delete d_child_hist; - } - //Regression - if (typeid(L) == typeid(T)) { - d_parent_pred->release(stream); - d_parent_count->release(stream); - d_mseout->release(stream); - d_predout->release(stream); - d_count->release(stream); - d_child_pred->release(stream); - d_child_count->release(stream); - - h_child_pred->release(stream); - h_child_count->release(stream); - h_mseout->release(stream); - h_predout->release(stream); - h_count->release(stream); - - delete d_child_pred; - delete d_child_count; - delete d_parent_pred; - delete d_parent_count; - delete d_mseout; - delete d_predout; - delete d_count; - - delete h_mseout; - delete h_predout; - delete h_count; - delete h_child_pred; - delete h_child_count; - } -} diff --git a/cpp/src/decisiontree/memory.h b/cpp/src/decisiontree/memory.h deleted file mode 100644 index c20219c4ab..0000000000 --- a/cpp/src/decisiontree/memory.h +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#include -#include -#include -#include -#include -#include - -template -struct TemporaryMemory { - //depth algorithm changer - const int swap_depth = 14; - static const int gather_threads = 256; - size_t parentsz, childsz, gather_max_nodes; - //Allocators parsed from CUML handle - std::shared_ptr device_allocator; - std::shared_ptr host_allocator; - - //Tree holder for gather algorithm - MLCommon::device_buffer> *d_sparsenodes = nullptr; - MLCommon::host_buffer> *h_sparsenodes = nullptr; - - //Temporary data buffer - MLCommon::device_buffer *temp_data = nullptr; - //Temporary CUB buffer - MLCommon::device_buffer *temp_cub_buffer = nullptr; - size_t temp_cub_bytes; - MLCommon::host_buffer *h_counter = nullptr; - MLCommon::device_buffer *d_counter = nullptr; - //Host/Device histograms and device minmaxs - MLCommon::device_buffer *d_globalminmax = nullptr; - MLCommon::host_buffer *h_globalminmax = nullptr; - MLCommon::device_buffer *d_mseout = nullptr; - MLCommon::device_buffer *d_predout = nullptr; - MLCommon::host_buffer *h_mseout = nullptr; - MLCommon::host_buffer *h_predout = nullptr; - //Total temp mem - size_t totalmem = 0; - - //CUDA stream - cudaStream_t stream; - - //No of SMs - int num_sms; - - //Maximum shared memory in GPU - size_t max_shared_mem; - - //For quantiles and colids; this part is common - MLCommon::device_buffer *d_quantile = nullptr; - MLCommon::host_buffer *h_quantile = nullptr; - MLCommon::device_buffer *d_colids = nullptr; - MLCommon::device_buffer *d_colstart = nullptr; - MLCommon::host_buffer *h_colids = nullptr; - MLCommon::host_buffer *h_colstart = nullptr; - - //For level algorithm - MLCommon::device_buffer *d_flags = nullptr; - MLCommon::device_buffer *d_histogram = nullptr; - MLCommon::host_buffer *h_histogram = nullptr; - MLCommon::host_buffer *h_split_colidx = nullptr; - MLCommon::host_buffer *h_split_binidx = nullptr; - MLCommon::device_buffer *d_split_colidx = nullptr; - MLCommon::device_buffer *d_split_binidx = nullptr; - MLCommon::host_buffer *h_new_node_flags = nullptr; - MLCommon::device_buffer *d_new_node_flags = nullptr; - MLCommon::host_buffer *h_parent_hist = nullptr; - MLCommon::host_buffer *h_child_hist = nullptr; - MLCommon::device_buffer *d_parent_hist = nullptr; - MLCommon::device_buffer *d_child_hist = nullptr; - MLCommon::host_buffer *h_parent_metric = nullptr; - MLCommon::host_buffer *h_child_best_metric = nullptr; - MLCommon::host_buffer *h_outgain = nullptr; - MLCommon::device_buffer *d_outgain = nullptr; - MLCommon::device_buffer *d_parent_metric = nullptr; - MLCommon::device_buffer *d_child_best_metric = nullptr; - MLCommon::device_buffer *d_sample_cnt = nullptr; - - MLCommon::device_buffer *d_parent_pred = nullptr; - MLCommon::device_buffer *d_parent_count = nullptr; - MLCommon::device_buffer *d_child_pred = nullptr; - MLCommon::device_buffer *d_child_count = nullptr; - MLCommon::device_buffer *d_count = nullptr; - MLCommon::host_buffer *h_count = 
nullptr; - MLCommon::host_buffer *h_child_pred = nullptr; - MLCommon::host_buffer *h_child_count = nullptr; - - int max_nodes_class = 0; - int max_nodes_pred = 0; - int max_nodes_mse = 0; - int max_nodes_per_level = 0; - int max_nodes_minmax = 0; - TemporaryMemory( - const std::shared_ptr device_allocator_in, - const std::shared_ptr host_allocator_in, - const cudaStream_t stream_in, int N, int Ncols, int n_unique, - const ML::DecisionTree::DecisionTreeParams &tree_params); - - TemporaryMemory(const raft::handle_t &handle, cudaStream_t stream_in, int N, - int Ncols, int n_unique, - const ML::DecisionTree::DecisionTreeParams &tree_params); - - ~TemporaryMemory(); - - void LevelMemAllocator( - int nrows, int ncols, int n_unique, - const ML::DecisionTree::DecisionTreeParams &tree_params); - - void LevelMemCleaner(); - - void print_info(int depth, int nrows, int ncols, float colper); -}; -#include "memory.cuh" diff --git a/cpp/src/decisiontree/quantile/quantile.cuh b/cpp/src/decisiontree/quantile/quantile.cuh index 556c90d176..6cc2fa8d91 100644 --- a/cpp/src/decisiontree/quantile/quantile.cuh +++ b/cpp/src/decisiontree/quantile/quantile.cuh @@ -30,165 +30,6 @@ using device_allocator = raft::mr::device::allocator; template using device_buffer = raft::mr::device::buffer; -template -__global__ void allcolsampler_kernel(const T *__restrict__ data, - const unsigned int *__restrict__ rowids, - const unsigned int *__restrict__ colids, - const int nrows, const int ncols, - const int rowoffset, T *sampledcols) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - - for (unsigned int i = tid; i < nrows * ncols; i += blockDim.x * gridDim.x) { - int newcolid = (int)(i / nrows); - int myrowstart; - if (colids != nullptr) { - myrowstart = colids[newcolid] * rowoffset; - } else { - myrowstart = newcolid * rowoffset; - } - - int index; - if (rowids != nullptr) { - index = rowids[i % nrows] + myrowstart; - } else { - index = i % nrows + myrowstart; - } - sampledcols[i] = data[index]; - } - return; -} - -__global__ void set_sorting_offset(const int nrows, const int ncols, - int *offsets) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid <= ncols) offsets[tid] = tid * nrows; - - return; -} - -template -__global__ void get_all_quantiles(const T *__restrict__ data, T *quantile, - const int nrows, const int ncols, - const int nbins) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid < nbins * ncols) { - int binoff = (int)(nrows / nbins); - int coloff = (int)(tid / nbins) * nrows; - quantile[tid] = data[((tid % nbins) + 1) * binoff - 1 + coloff]; - } - return; -} - -template -void preprocess_quantile(const T *data, const unsigned int *rowids, - const int n_sampled_rows, const int ncols, - const int rowoffset, const int nbins, - std::shared_ptr> tempmem) { - /* - // Dynamically determine batch_cols (number of columns processed per loop iteration) from the available device memory. - size_t free_mem, total_mem; - CUDA_CHECK(cudaMemGetInfo(&free_mem, &total_mem)); - int max_ncols = free_mem / (2 * n_sampled_rows * sizeof(T)); - int batch_cols = (max_ncols > ncols) ? ncols : max_ncols; - ASSERT(max_ncols != 0, "Cannot preprocess quantiles due to insufficient device memory."); - */ - - ML::PUSH_RANGE("preprocessing quantile @quantile.cuh"); - int batch_cols = - 1; // Processing one column at a time, for now, until an appropriate getMemInfo function is provided for the raft::mr::device::allocator interface. 
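preprocess_quantile, whose deletion continues below, sorts each sampled column with cub::DeviceRadixSort::SortKeys and then has get_all_quantiles read one threshold per bin at evenly spaced ranks in the sorted keys. A host-side sketch of the same indexing, assuming nbins <= col.size() (binoff = nrows / nbins; bin b reads rank (b + 1) * binoff - 1, exactly as the kernel does):

    #include <algorithm>
    #include <vector>

    // CPU restatement of the removed sort-then-subsample quantile pipeline.
    template <typename T>
    std::vector<T> column_quantiles(std::vector<T> col, int nbins) {
      std::sort(col.begin(), col.end());  // device code used radix sort
      const int binoff = static_cast<int>(col.size()) / nbins;
      std::vector<T> q(nbins);
      for (int b = 0; b < nbins; ++b) q[b] = col[(b + 1) * binoff - 1];
      return q;
    }
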
- - int threads = 128; - MLCommon::device_buffer *d_offsets; - MLCommon::device_buffer *d_keys_out; - const T *d_keys_in; - int blocks; - if (tempmem->temp_data != nullptr) { - T *d_keys_out = tempmem->temp_data->data(); - unsigned int *colids = nullptr; - blocks = raft::ceildiv(ncols * n_sampled_rows, threads); - allcolsampler_kernel<<stream>>>( - data, rowids, colids, n_sampled_rows, ncols, rowoffset, - d_keys_out); // d_keys_in already allocated for all ncols - CUDA_CHECK(cudaGetLastError()); - d_keys_in = d_keys_out; - } else { - d_keys_in = data; - } - - d_offsets = new MLCommon::device_buffer(tempmem->device_allocator, - tempmem->stream, batch_cols + 1); - - blocks = raft::ceildiv(batch_cols + 1, threads); - ML::PUSH_RANGE("set_sorting_offset kernel @quantile.cuh"); - set_sorting_offset<<stream>>>( - n_sampled_rows, batch_cols, d_offsets->data()); - ML::POP_RANGE(); - CUDA_CHECK(cudaGetLastError()); - - // Determine temporary device storage requirements - MLCommon::device_buffer *d_temp_storage = nullptr; - size_t temp_storage_bytes = 0; - - int batch_cnt = - raft::ceildiv(ncols, batch_cols); // number of loop iterations - int last_batch_size = - ncols - batch_cols * (batch_cnt - 1); // number of columns in last batch - int batch_items = - n_sampled_rows * batch_cols; // used to determine d_temp_storage size - - d_keys_out = new MLCommon::device_buffer(tempmem->device_allocator, - tempmem->stream, batch_items); - ML::PUSH_RANGE( - "DecisionTree::cub::DeviceRadixSort::SortKeys over batch_items " - "@quantile.cuh"); - CUDA_CHECK(cub::DeviceRadixSort::SortKeys( - d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out->data(), - batch_items, 0, 8 * sizeof(T), tempmem->stream)); - ML::POP_RANGE(); - // Allocate temporary storage - d_temp_storage = new MLCommon::device_buffer( - tempmem->device_allocator, tempmem->stream, temp_storage_bytes); - - ML::PUSH_RANGE("iterative quantile computation for each batch"); - // Compute quantiles for cur_batch_cols columns per loop iteration. - for (int batch = 0; batch < batch_cnt; batch++) { - int cur_batch_cols = (batch == batch_cnt - 1) - ? 
last_batch_size - : batch_cols; // properly handle the last batch - - int batch_offset = batch * n_sampled_rows * batch_cols; - int quantile_offset = batch * nbins * batch_cols; - ML::PUSH_RANGE("DeviceRadixSort::SortKeys"); - CUDA_CHECK(cub::DeviceRadixSort::SortKeys( - (void *)d_temp_storage->data(), temp_storage_bytes, - &d_keys_in[batch_offset], d_keys_out->data(), n_sampled_rows, 0, - 8 * sizeof(T), tempmem->stream)); - ML::POP_RANGE(); - - blocks = raft::ceildiv(cur_batch_cols * nbins, threads); - ML::PUSH_RANGE("get_all_quantiles kernel @quantile.cuh"); - get_all_quantiles<<stream>>>( - d_keys_out->data(), &tempmem->d_quantile->data()[quantile_offset], - n_sampled_rows, cur_batch_cols, nbins); - ML::POP_RANGE(); - - CUDA_CHECK(cudaGetLastError()); - CUDA_CHECK(cudaStreamSynchronize(tempmem->stream)); - } - ML::POP_RANGE(); - raft::update_host(tempmem->h_quantile->data(), tempmem->d_quantile->data(), - nbins * ncols, tempmem->stream); - d_keys_out->release(tempmem->stream); - d_offsets->release(tempmem->stream); - d_temp_storage->release(tempmem->stream); - delete d_keys_out; - delete d_offsets; - delete d_temp_storage; - ML::POP_RANGE(); - - return; -} - template __global__ void computeQuantilesSorted(T *quantiles, const int n_bins, const T *sorted_data, const int length) { diff --git a/cpp/src/decisiontree/quantile/quantile.h b/cpp/src/decisiontree/quantile/quantile.h index 2a15cf368c..9e0cb25353 100644 --- a/cpp/src/decisiontree/quantile/quantile.h +++ b/cpp/src/decisiontree/quantile/quantile.h @@ -19,18 +19,9 @@ #include #include -template -struct TemporaryMemory; - namespace ML { namespace DecisionTree { -template -void preprocess_quantile(const T *data, const unsigned int *rowids, - const int n_sampled_rows, const int ncols, - const int rowoffset, const int nbins, - std::shared_ptr> tempmem); - template void computeQuantiles( T *quantiles, int n_bins, const T *data, int n_rows, int n_cols, diff --git a/cpp/src/randomforest/randomforest.cu b/cpp/src/randomforest/randomforest.cu index c78bf7a9e7..366cacfb37 100644 --- a/cpp/src/randomforest/randomforest.cu +++ b/cpp/src/randomforest/randomforest.cu @@ -611,31 +611,15 @@ RF_metrics score(const raft::handle_t& user_handle, } RF_params set_rf_params(int max_depth, int max_leaves, float max_features, - int n_bins, int split_algo, int min_samples_leaf, - int min_samples_split, float min_impurity_decrease, - bool bootstrap_features, bool bootstrap, int n_trees, - float max_samples, uint64_t seed, - CRITERION split_criterion, bool quantile_per_tree, - int cfg_n_streams, bool use_experimental_backend, + int n_bins, int min_samples_leaf, int min_samples_split, + float min_impurity_decrease, bool bootstrap, + int n_trees, float max_samples, uint64_t seed, + CRITERION split_criterion, int cfg_n_streams, int max_batch_size) { - // give deprecation notice for use of bootstrap_features - if (bootstrap_features) { - CUML_LOG_WARN( - "Parameter 'bootstrap_features' is deprecated and will be" - " removed in 0.21 release. Please use 'max_features' instead."); - if (max_features == 1.f) { - CUML_LOG_WARN( - "Parameter conflict: 'max_features' is set to 1.0 when " - "'bootstrap_features' is enabled. 
" - "'max_features' will be used to override 'bootstrap_features'."); - } - } DecisionTree::DecisionTreeParams tree_params; DecisionTree::set_tree_params( - tree_params, max_depth, max_leaves, max_features, n_bins, split_algo, - min_samples_leaf, min_samples_split, min_impurity_decrease, - bootstrap_features, split_criterion, quantile_per_tree, - use_experimental_backend, max_batch_size); + tree_params, max_depth, max_leaves, max_features, n_bins, min_samples_leaf, + min_samples_split, min_impurity_decrease, split_criterion, max_batch_size); RF_params rf_params; rf_params.n_trees = n_trees; rf_params.bootstrap = bootstrap; diff --git a/cpp/src/randomforest/randomforest_impl.cuh b/cpp/src/randomforest/randomforest_impl.cuh index 0b2480f108..c87a06035b 100644 --- a/cpp/src/randomforest/randomforest_impl.cuh +++ b/cpp/src/randomforest/randomforest_impl.cuh @@ -18,9 +18,9 @@ #ifndef _OPENMP #define omp_get_thread_num() 0 #endif -#include #include #include +#include #include #include #include @@ -189,56 +189,17 @@ void rfClassifier::fit(const raft::handle_t& user_handle, const T* input, selected_rows[i] = new MLCommon::device_buffer( handle.get_device_allocator(), s, n_sampled_rows); } - - std::shared_ptr> tempmem[n_streams]; - if (this->rf_params.tree_params.use_experimental_backend) { - // TemporaryMemory is unused for batched (new) backend - for (int i = 0; i < n_streams; i++) { - tempmem[i] = nullptr; - } - } else { - // Allocate TemporaryMemory for each stream - for (int i = 0; i < n_streams; i++) { - tempmem[i] = std::make_shared>( - handle, handle.get_internal_stream(i), n_rows, n_cols, n_unique_labels, - this->rf_params.tree_params); - } - } - - std::unique_ptr> global_quantiles_buffer = nullptr; - T* global_quantiles = nullptr; auto quantile_size = this->rf_params.tree_params.n_bins * n_cols; + MLCommon::device_buffer global_quantiles( + handle.get_device_allocator(), handle.get_stream(), quantile_size); //Preprocess once only per forest - if (this->rf_params.tree_params.use_experimental_backend) { - // Using batched backend - // allocate space for d_global_quantiles - global_quantiles_buffer = std::make_unique>( - handle.get_device_allocator(), handle.get_stream(), quantile_size); - global_quantiles = global_quantiles_buffer->data(); - DecisionTree::computeQuantiles( - global_quantiles, this->rf_params.tree_params.n_bins, input, n_rows, - n_cols, handle.get_device_allocator(), handle.get_stream()); - CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); - } else { - if ((this->rf_params.tree_params.split_algo == - SPLIT_ALGO::GLOBAL_QUANTILE) && - !(this->rf_params.tree_params.quantile_per_tree)) { - // Using level (old) backend - DecisionTree::preprocess_quantile(input, nullptr, n_rows, n_cols, n_rows, - this->rf_params.tree_params.n_bins, - tempmem[0]); - for (int i = 1; i < n_streams; i++) { - CUDA_CHECK(cudaMemcpyAsync( - tempmem[i]->d_quantile->data(), tempmem[0]->d_quantile->data(), - this->rf_params.tree_params.n_bins * n_cols * sizeof(T), - cudaMemcpyDeviceToDevice, tempmem[i]->stream)); - memcpy((void*)(tempmem[i]->h_quantile->data()), - (void*)(tempmem[0]->h_quantile->data()), - this->rf_params.tree_params.n_bins * n_cols * sizeof(T)); - } - } - } + // Using batched backend + // allocate space for d_global_quantiles + DecisionTree::computeQuantiles( + global_quantiles.data(), this->rf_params.tree_params.n_bins, input, n_rows, + n_cols, handle.get_device_allocator(), handle.get_stream()); + CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); #pragma omp parallel for 
num_threads(n_streams) for (int i = 0; i < this->rf_params.n_trees; i++) { @@ -263,7 +224,7 @@ void rfClassifier::fit(const raft::handle_t& user_handle, const T* input, handle.get_internal_stream(stream_id), input, n_cols, n_rows, labels, rowids, n_sampled_rows, n_unique_labels, tree_ptr, this->rf_params.tree_params, this->rf_params.seed, - global_quantiles, tempmem[stream_id]); + global_quantiles.data()); } //Cleanup for (int i = 0; i < n_streams; i++) { @@ -271,9 +232,6 @@ void rfClassifier::fit(const raft::handle_t& user_handle, const T* input, CUDA_CHECK(cudaStreamSynchronize(s)); selected_rows[i]->release(s); delete selected_rows[i]; - if (!this->rf_params.tree_params.use_experimental_backend) { - tempmem[i].reset(); - } } CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); ML::POP_RANGE(); @@ -498,55 +456,14 @@ void rfRegressor::fit(const raft::handle_t& user_handle, const T* input, handle.get_device_allocator(), s, n_sampled_rows); } - std::shared_ptr> tempmem[n_streams]; - if (this->rf_params.tree_params.use_experimental_backend) { - // TemporaryMemory is unused for batched (new) backend - for (int i = 0; i < n_streams; i++) { - tempmem[i] = nullptr; - } - } else { - // Allocate TemporaryMemory for each stream - for (int i = 0; i < n_streams; i++) { - tempmem[i] = std::make_shared>( - handle, handle.get_internal_stream(i), n_rows, n_cols, 1, - this->rf_params.tree_params); - } - } - - std::unique_ptr> global_quantiles_buffer = nullptr; - T* global_quantiles = nullptr; auto quantile_size = this->rf_params.tree_params.n_bins * n_cols; + MLCommon::device_buffer global_quantiles( + handle.get_device_allocator(), handle.get_stream(), quantile_size); - //Preprocess once only per forest - if (this->rf_params.tree_params.use_experimental_backend) { - // Using batched backend - // allocate space for d_global_quantiles - global_quantiles_buffer = std::make_unique>( - handle.get_device_allocator(), handle.get_stream(), quantile_size); - global_quantiles = global_quantiles_buffer->data(); - DecisionTree::computeQuantiles( - global_quantiles, this->rf_params.tree_params.n_bins, input, n_rows, - n_cols, handle.get_device_allocator(), handle.get_stream()); - CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); - } else { - if ((this->rf_params.tree_params.split_algo == - SPLIT_ALGO::GLOBAL_QUANTILE) && - !(this->rf_params.tree_params.quantile_per_tree)) { - // Using level (old) backend - DecisionTree::preprocess_quantile(input, nullptr, n_rows, n_cols, n_rows, - this->rf_params.tree_params.n_bins, - tempmem[0]); - for (int i = 1; i < n_streams; i++) { - CUDA_CHECK(cudaMemcpyAsync( - tempmem[i]->d_quantile->data(), tempmem[0]->d_quantile->data(), - this->rf_params.tree_params.n_bins * n_cols * sizeof(T), - cudaMemcpyDeviceToDevice, tempmem[i]->stream)); - memcpy((void*)(tempmem[i]->h_quantile->data()), - (void*)(tempmem[0]->h_quantile->data()), - this->rf_params.tree_params.n_bins * n_cols * sizeof(T)); - } - } - } + DecisionTree::computeQuantiles( + global_quantiles.data(), this->rf_params.tree_params.n_bins, input, n_rows, + n_cols, handle.get_device_allocator(), handle.get_stream()); + CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); #pragma omp parallel for num_threads(n_streams) for (int i = 0; i < this->rf_params.n_trees; i++) { @@ -570,7 +487,7 @@ void rfRegressor::fit(const raft::handle_t& user_handle, const T* input, handle.get_internal_stream(stream_id), input, n_cols, n_rows, labels, rowids, n_sampled_rows, tree_ptr, this->rf_params.tree_params, this->rf_params.seed, - 
global_quantiles, tempmem[stream_id]); + global_quantiles.data()); } //Cleanup for (int i = 0; i < n_streams; i++) { @@ -578,9 +495,6 @@ void rfRegressor::fit(const raft::handle_t& user_handle, const T* input, CUDA_CHECK(cudaStreamSynchronize(s)); selected_rows[i]->release(s); delete selected_rows[i]; - if (!this->rf_params.tree_params.use_experimental_backend) { - tempmem[i].reset(); - } } CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); ML::POP_RANGE(); diff --git a/cpp/test/sg/decisiontree_batchedlevel_algo.cu b/cpp/test/sg/decisiontree_batchedlevel_algo.cu index 4e608992dc..ffe2ebd9b5 100644 --- a/cpp/test/sg/decisiontree_batchedlevel_algo.cu +++ b/cpp/test/sg/decisiontree_batchedlevel_algo.cu @@ -14,7 +14,6 @@ * limitations under the License. */ -#include #include #include #include @@ -50,9 +49,8 @@ class DtBaseTest : public ::testing::TestWithParam { CUDA_CHECK(cudaStreamCreate(&stream)); handle->set_stream(stream); set_tree_params(params, inparams.max_depth, 1 << inparams.max_depth, 1.f, - inparams.nbins, SPLIT_ALGO::GLOBAL_QUANTILE, 0, - inparams.nbins, inparams.min_gain, false, - inparams.splitType, false, true, 128); + inparams.nbins, 0, inparams.nbins, inparams.min_gain, + inparams.splitType, 128); auto allocator = handle->get_device_allocator(); data = (T*)allocator->allocate(sizeof(T) * inparams.M * inparams.N, stream); labels = (L*)allocator->allocate(sizeof(L) * inparams.M, stream); @@ -160,11 +158,6 @@ TEST_P(DtRegTestF, Test) { inparams.M, labels, quantiles, rowids, inparams.M, 0, params, stream, sparsetree, num_leaves, depth); // goes all the way to max-depth -#if CUDART_VERSION >= 11020 - if (inparams.splitType == CRITERION::MAE) { - GTEST_SKIP(); - } -#endif ASSERT_EQ(depth, inparams.max_depth); } INSTANTIATE_TEST_CASE_P(BatchedLevelAlgo, DtRegTestF, diff --git a/cpp/test/sg/decisiontree_batchedlevel_unittest.cu b/cpp/test/sg/decisiontree_batchedlevel_unittest.cu index e3c3cdc3fb..55dac67b1a 100644 --- a/cpp/test/sg/decisiontree_batchedlevel_unittest.cu +++ b/cpp/test/sg/decisiontree_batchedlevel_unittest.cu @@ -16,7 +16,6 @@ #include -#include #include #include #include @@ -55,15 +54,11 @@ class BatchedLevelAlgoUnitTestFixture { params.max_leaves = 8; params.max_features = 1.0f; params.n_bins = n_bins; - params.split_algo = 1; params.min_samples_leaf = 0; params.min_samples_split = 0; - params.bootstrap_features = false; - params.quantile_per_tree = false; params.split_criterion = CRITERION::MSE; params.min_impurity_decrease = 0.0f; params.max_batch_size = 8; - params.use_experimental_backend = true; h_data = {-1.0f, 0.0f, 2.0f, 0.0f, -2.0f, 0.0f, 1.0f, 0.0f, 3.0f, 0.0f}; // column-major @@ -75,6 +70,8 @@ class BatchedLevelAlgoUnitTestFixture { data = static_cast( d_allocator->allocate(sizeof(DataT) * n_row * n_col, 0)); + d_quantiles = static_cast( + d_allocator->allocate(sizeof(DataT) * n_bins * n_col, 0)); labels = static_cast(d_allocator->allocate(sizeof(LabelT) * n_row, 0)); row_ids = @@ -98,17 +95,12 @@ class BatchedLevelAlgoUnitTestFixture { raft::update_device(data, h_data.data(), n_row * n_col, 0); raft::update_device(labels, h_labels.data(), n_row, 0); + computeQuantiles(d_quantiles, n_bins, data, n_row, n_col, d_allocator, + nullptr); MLCommon::iota(row_ids, 0, 1, n_row, 0); - tempmem = std::make_shared>( - *raft_handle, cudaStream_t(0), n_row, n_col, 0, params); - preprocess_quantile(data, reinterpret_cast(row_ids), n_row, - n_col, n_row, n_bins, tempmem); - DataT* quantiles = tempmem->d_quantile->data(); CUDA_CHECK(cudaStreamSynchronize(0)); - 
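With the level backend gone, rfClassifier::fit, rfRegressor::fit, and this test fixture all converge on one pattern: allocate a single n_bins * n_cols device buffer, fill it once per forest with computeQuantiles, and hand the same pointer to every tree. A hedged consolidation of that pattern — the wrapper name is invented, and it assumes cuML's 21.06-era internal headers; the calls inside mirror the diff:

    // Sketch only: relies on cuML-internal types (raft::handle_t,
    // MLCommon::device_buffer, ML::DecisionTree::computeQuantiles).
    template <typename T>
    void compute_forest_quantiles(const raft::handle_t& handle, const T* input,
                                  int n_rows, int n_cols, int n_bins,
                                  MLCommon::device_buffer<T>& quantiles) {
      // quantiles must hold n_bins * n_cols values; it is computed once,
      // then quantiles.data() is passed to every per-tree fit call.
      ML::DecisionTree::computeQuantiles(quantiles.data(), n_bins, input,
                                         n_rows, n_cols,
                                         handle.get_device_allocator(),
                                         handle.get_stream());
      CUDA_CHECK(cudaStreamSynchronize(handle.get_stream()));
    }
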
h_quantiles = tempmem->h_quantile->data(); - input.data = data; input.labels = labels; input.M = n_row; @@ -117,12 +109,13 @@ class BatchedLevelAlgoUnitTestFixture { input.nSampledCols = n_col; input.rowids = row_ids; input.nclasses = 0; // not applicable for regression - input.quantiles = quantiles; + input.quantiles = d_quantiles; } void TearDown() { auto d_allocator = raft_handle->get_device_allocator(); d_allocator->deallocate(data, sizeof(DataT) * n_row * n_col, 0); + d_allocator->deallocate(d_quantiles, sizeof(DataT) * n_bins * n_col, 0); d_allocator->deallocate(labels, sizeof(LabelT) * n_row, 0); d_allocator->deallocate(row_ids, sizeof(IdxT) * n_row, 0); d_allocator->deallocate(curr_nodes, sizeof(NodeT) * max_batch, 0); @@ -136,12 +129,11 @@ class BatchedLevelAlgoUnitTestFixture { DecisionTreeParams params; std::unique_ptr raft_handle; - std::shared_ptr> tempmem; std::vector h_data; std::vector h_labels; - DataT* h_quantiles; + DataT* d_quantiles; Traits::InputT input; NodeT* curr_nodes; @@ -156,14 +148,6 @@ class BatchedLevelAlgoUnitTestFixture { IdxT* row_ids; }; -class TestQuantiles : public ::testing::TestWithParam, - protected BatchedLevelAlgoUnitTestFixture { - protected: - void SetUp() override { BatchedLevelAlgoUnitTestFixture::SetUp(); } - - void TearDown() override { BatchedLevelAlgoUnitTestFixture::TearDown(); } -}; - class TestNodeSplitKernel : public ::testing::TestWithParam, protected BatchedLevelAlgoUnitTestFixture { @@ -181,23 +165,6 @@ class TestMetric : public ::testing::TestWithParam, void TearDown() override { BatchedLevelAlgoUnitTestFixture::TearDown(); } }; -TEST_P(TestQuantiles, Quantiles) { - /* Ensure that quantiles are computed correctly */ - std::vector expected_quantiles[]{{-2.0f, -1.0f, 0.0f, 2.0f}, - {0.0f, 1.0f, 3.0f}}; - for (int col = 0; col < n_col; col++) { - std::vector col_quantile(n_bins); - std::copy(h_quantiles + n_bins * col, h_quantiles + n_bins * (col + 1), - col_quantile.begin()); - auto last = std::unique(col_quantile.begin(), col_quantile.end()); - col_quantile.erase(last, col_quantile.end()); - EXPECT_EQ(col_quantile, expected_quantiles[col]); - } -} - -INSTANTIATE_TEST_SUITE_P(BatchedLevelAlgoUnitTest, TestQuantiles, - ::testing::Values(NoOpParams{})); - TEST_P(TestNodeSplitKernel, MinSamplesSplitLeaf) { auto test_params = GetParam(); diff --git a/cpp/test/sg/rf_accuracy_test.cu b/cpp/test/sg/rf_accuracy_test.cu index 5facb71ffa..dda8f9b8dc 100644 --- a/cpp/test/sg/rf_accuracy_test.cu +++ b/cpp/test/sg/rf_accuracy_test.cu @@ -79,27 +79,22 @@ class RFClassifierAccuracyTest : public ::testing::TestWithParam { private: void setRFParams() { - auto algo = SPLIT_ALGO::GLOBAL_QUANTILE; auto sc = CRITERION::CRITERION_END; - rfp = set_rf_params(0, /*max_depth */ - -1, /* max_leaves */ - 1.0, /* max_features */ - 16, /* n_bins */ - algo, /* split_algo */ - 2, /* min_samples_leaf */ - 2, /* min_samples_split */ - 0.f, /* min_impurity_decrease */ - false, /* bootstrap_features */ - true, /* bootstrap */ - 1, /* n_trees */ - 1.0, /* max_samples */ - 0, /* seed */ - sc, /* split_criterion */ - false, /* quantile_per_tree */ - 1, /* n_streams */ - true, /* use_experimental_backend */ - 128 /* max_batch_size */ + rfp = set_rf_params(0, /*max_depth */ + -1, /* max_leaves */ + 1.0, /* max_features */ + 16, /* n_bins */ + 2, /* min_samples_leaf */ + 2, /* min_samples_split */ + 0.f, /* min_impurity_decrease */ + true, /* bootstrap */ + 1, /* n_trees */ + 1.0, /* max_samples */ + 0, /* seed */ + sc, /* split_criterion */ + 1, /* n_streams */ + 128 /* 
max_batch_size */ ); } diff --git a/cpp/test/sg/rf_batched_classification_test.cu b/cpp/test/sg/rf_batched_classification_test.cu index b158c5d202..caade81b94 100644 --- a/cpp/test/sg/rf_batched_classification_test.cu +++ b/cpp/test/sg/rf_batched_classification_test.cu @@ -33,9 +33,7 @@ struct RfInputs { int max_depth; int max_leaves; bool bootstrap; - bool bootstrap_features; int n_bins; - int split_algo; int min_samples_leaf; int min_samples_split; float min_impurity_decrease; @@ -53,10 +51,9 @@ class RFBatchedClsTest : public ::testing::TestWithParam { RF_params rf_params; rf_params = set_rf_params( params.max_depth, params.max_leaves, params.max_features, params.n_bins, - params.split_algo, params.min_samples_leaf, params.min_samples_split, - params.min_impurity_decrease, params.bootstrap_features, params.bootstrap, - params.n_trees, params.max_samples, 0, params.split_criterion, false, - params.n_streams, true, 128); + params.min_samples_leaf, params.min_samples_split, + params.min_impurity_decrease, params.bootstrap, params.n_trees, + params.max_samples, 0, params.split_criterion, params.n_streams, 128); CUDA_CHECK(cudaStreamCreate(&stream)); handle.reset(new raft::handle_t(rf_params.n_streams)); @@ -142,17 +139,16 @@ class RFBatchedClsTest : public ::testing::TestWithParam { //------------------------------------------------------------------------------------------------------------------------------------- const std::vector inputsf2_clf = { // Simple non-crash tests with small datasets - {100, 59, 1, 1.0f, 0.4f, 16, -1, true, false, 10, SPLIT_ALGO::GLOBAL_QUANTILE, - 2, 2, 0.0, 2, CRITERION::GINI, 0.0f}, - {101, 59, 2, 1.0f, 0.4f, 10, -1, true, false, 13, SPLIT_ALGO::GLOBAL_QUANTILE, - 2, 2, 0.0, 2, CRITERION::GINI, 0.0f}, - {100, 1, 2, 1.0f, 0.4f, 10, -1, true, false, 15, SPLIT_ALGO::GLOBAL_QUANTILE, - 2, 2, 0.0, 2, CRITERION::GINI, 0.0f}, + {100, 59, 1, 1.0f, 0.4f, 16, -1, true, 10, 2, 2, 0.0, 2, CRITERION::GINI, + 0.0f}, + {101, 59, 2, 1.0f, 0.4f, 10, -1, true, 13, 2, 2, 0.0, 2, CRITERION::GINI, + 0.0f}, + {100, 1, 2, 1.0f, 0.4f, 10, -1, true, 15, 2, 2, 0.0, 2, CRITERION::GINI, + 0.0f}, // Simple accuracy tests - {20000, 10, 25, 1.0f, 0.4f, 16, -1, true, false, 10, - SPLIT_ALGO::GLOBAL_QUANTILE, 2, 2, 0.0, 2, CRITERION::GINI}, - {20000, 10, 5, 1.0f, 0.4f, 14, -1, true, false, 10, - SPLIT_ALGO::GLOBAL_QUANTILE, 2, 2, 0.0, 2, CRITERION::ENTROPY}}; + {20000, 10, 25, 1.0f, 0.4f, 16, -1, true, 10, 2, 2, 0.0, 2, CRITERION::GINI}, + {20000, 10, 5, 1.0f, 0.4f, 14, -1, true, 10, 2, 2, 0.0, 2, + CRITERION::ENTROPY}}; typedef RFBatchedClsTest RFBatchedClsTestF; TEST_P(RFBatchedClsTestF, Fit) { diff --git a/cpp/test/sg/rf_batched_regression_test.cu b/cpp/test/sg/rf_batched_regression_test.cu index 9eb279b215..47bfa1fae6 100644 --- a/cpp/test/sg/rf_batched_regression_test.cu +++ b/cpp/test/sg/rf_batched_regression_test.cu @@ -37,9 +37,7 @@ struct RfInputs { int max_depth; int max_leaves; bool bootstrap; - bool bootstrap_features; int n_bins; - int split_algo; int min_samples_leaf; int min_samples_split; float min_impurity_decrease; @@ -57,10 +55,9 @@ class RFBatchedRegTest : public ::testing::TestWithParam { RF_params rf_params; rf_params = set_rf_params( params.max_depth, params.max_leaves, params.max_features, params.n_bins, - params.split_algo, params.min_samples_leaf, params.min_samples_split, - params.min_impurity_decrease, params.bootstrap_features, params.bootstrap, - params.n_trees, params.max_samples, 0, params.split_criterion, false, - params.n_streams, true, 128); + 
params.min_samples_leaf, params.min_samples_split, + params.min_impurity_decrease, params.bootstrap, params.n_trees, + params.max_samples, 0, params.split_criterion, params.n_streams, 128); CUDA_CHECK(cudaStreamCreate(&stream)); handle.reset(new raft::handle_t(rf_params.n_streams)); @@ -120,15 +117,15 @@ class RFBatchedRegTest : public ::testing::TestWithParam { //------------------------------------------------------------------------------------------------------------------------------------- const std::vector inputs = { - RfInputs{5, 1, 1, 1.0f, 1.0f, 1, -1, false, false, 5, - SPLIT_ALGO::GLOBAL_QUANTILE, 1, 2, 0.0, 1, CRITERION::MSE, -5.0}, + RfInputs{5, 1, 1, 1.0f, 1.0f, 1, -1, false, 5, 1, 2, 0.0, 1, CRITERION::MSE, + -5.0}, // Small datasets to repro corner cases as in #3107 (test for crash) - {101, 57, 2, 1.0f, 1.0f, 2, -1, false, false, 13, SPLIT_ALGO::GLOBAL_QUANTILE, - 2, 2, 0.0, 2, CRITERION::MSE, -10.0}, + {101, 57, 2, 1.0f, 1.0f, 2, -1, false, 13, 2, 2, 0.0, 2, CRITERION::MSE, + -10.0}, // Larger datasets for accuracy - {2000, 20, 20, 1.0f, 0.6f, 13, -1, true, false, 10, - SPLIT_ALGO::GLOBAL_QUANTILE, 2, 2, 0.0, 2, CRITERION::MSE, 0.68f}}; + {2000, 20, 20, 1.0f, 0.6f, 13, -1, true, 10, 2, 2, 0.0, 2, CRITERION::MSE, + 0.68f}}; typedef RFBatchedRegTest RFBatchedRegTestF; TEST_P(RFBatchedRegTestF, Fit) { ASSERT_GT(accuracy, params.min_expected_acc); } diff --git a/cpp/test/sg/rf_depth_test.cu b/cpp/test/sg/rf_depth_test.cu index 610d2f728b..c032163980 100644 --- a/cpp/test/sg/rf_depth_test.cu +++ b/cpp/test/sg/rf_depth_test.cu @@ -34,9 +34,7 @@ struct RfInputs { int max_depth; int max_leaves; bool bootstrap; - bool bootstrap_features; int n_bins; - int split_algo; int min_samples_leaf; int min_samples_split; float min_impurity_decrease; @@ -49,30 +47,15 @@ class RfClassifierDepthTest : public ::testing::TestWithParam { protected: void basicTest() { const int max_depth = ::testing::TestWithParam::GetParam(); - params = RfInputs{10000, - 10, - 1, - 1.0f, - 1.0f, - max_depth, - -1, - false, - false, - 8, - SPLIT_ALGO::GLOBAL_QUANTILE, - 2, - 2, - 0.0, - 2, - CRITERION::ENTROPY}; + params = RfInputs{10000, 10, 1, 1.0f, 1.0f, max_depth, -1, false, + 8, 2, 2, 0.0, 2, CRITERION::ENTROPY}; RF_params rf_params; rf_params = set_rf_params( params.max_depth, params.max_leaves, params.max_features, params.n_bins, - params.split_algo, params.min_samples_leaf, params.min_samples_split, - params.min_impurity_decrease, params.bootstrap_features, params.bootstrap, - params.n_trees, params.max_samples, 0, params.split_criterion, false, - params.n_streams, true, 128); + params.min_samples_leaf, params.min_samples_split, + params.min_impurity_decrease, params.bootstrap, params.n_trees, + params.max_samples, 0, params.split_criterion, params.n_streams, 128); int data_len = params.n_rows * params.n_cols; raft::allocate(data, data_len); @@ -141,30 +124,15 @@ class RfRegressorDepthTest : public ::testing::TestWithParam { protected: void basicTest() { const int max_depth = ::testing::TestWithParam::GetParam(); - params = RfInputs{5000, - 10, - 1, - 1.0f, - 1.0f, - max_depth, - -1, - false, - false, - 8, - SPLIT_ALGO::GLOBAL_QUANTILE, - 2, - 2, - 0.0, - 2, - CRITERION::MSE}; + params = RfInputs{5000, 10, 1, 1.0f, 1.0f, max_depth, -1, + false, 8, 2, 2, 0.0, 2, CRITERION::MSE}; RF_params rf_params; rf_params = set_rf_params( params.max_depth, params.max_leaves, params.max_features, params.n_bins, - params.split_algo, params.min_samples_leaf, params.min_samples_split, - params.min_impurity_decrease, 
params.bootstrap_features, params.bootstrap, - params.n_trees, params.max_samples, 0, params.split_criterion, false, - params.n_streams, true, 128); + params.min_samples_leaf, params.min_samples_split, + params.min_impurity_decrease, params.bootstrap, params.n_trees, + params.max_samples, 0, params.split_criterion, params.n_streams, 128); int data_len = params.n_rows * params.n_cols; raft::allocate(data, data_len); diff --git a/cpp/test/sg/rf_test.cu b/cpp/test/sg/rf_test.cu index a17fc88a49..4c7bf14fc6 100644 --- a/cpp/test/sg/rf_test.cu +++ b/cpp/test/sg/rf_test.cu @@ -34,9 +34,7 @@ struct RfInputs { int max_depth; int max_leaves; bool bootstrap; - bool bootstrap_features; int n_bins; - int split_algo; int min_samples_leaf; int min_samples_split; float min_impurity_decrease; @@ -58,10 +56,9 @@ class RfClassifierTest : public ::testing::TestWithParam> { RF_params rf_params; rf_params = set_rf_params( params.max_depth, params.max_leaves, params.max_features, params.n_bins, - params.split_algo, params.min_samples_leaf, params.min_samples_split, - params.min_impurity_decrease, params.bootstrap_features, params.bootstrap, - params.n_trees, params.max_samples, 0, params.split_criterion, false, - params.n_streams, true, 128); + params.min_samples_leaf, params.min_samples_split, + params.min_impurity_decrease, params.bootstrap, params.n_trees, + params.max_samples, 0, params.split_criterion, params.n_streams, 128); //-------------------------------------------------------- // Random Forest @@ -158,10 +155,9 @@ class RfRegressorTest : public ::testing::TestWithParam> { RF_params rf_params; rf_params = set_rf_params( params.max_depth, params.max_leaves, params.max_features, params.n_bins, - params.split_algo, params.min_samples_leaf, params.min_samples_split, - params.min_impurity_decrease, params.bootstrap_features, params.bootstrap, - params.n_trees, params.max_samples, 0, params.split_criterion, false, - params.n_streams, false, 128); + params.min_samples_leaf, params.min_samples_split, + params.min_impurity_decrease, params.bootstrap, params.n_trees, + params.max_samples, 0, params.split_criterion, params.n_streams, 128); //-------------------------------------------------------- // Random Forest @@ -243,59 +239,41 @@ class RfRegressorTest : public ::testing::TestWithParam> { //------------------------------------------------------------------------------------------------------------------------------------- const std::vector> inputsf2_clf = { - {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::GINI}, // single tree forest, bootstrap false, depth 8, 4 bins - {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, + {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, 4, 1, 2, 0.0, 2, + CRITERION::GINI}, // single tree forest, bootstrap false, depth 8, 4 bins + {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, 4, 1, 2, 0.0, 2, CRITERION::GINI}, // single tree forest, bootstrap false, depth of 8, 4 bins - {4, 2, 10, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, + {4, 2, 10, 1.0f, 1.0f, 4, 7, -1, false, 4, 1, 2, 0.0, 2, CRITERION:: GINI}, //forest with 10 trees, all trees should produce identical predictions (no bootstrapping or column subsampling) - {4, 2, 10, 0.8f, 0.8f, 4, 7, -1, true, false, 3, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, + {4, 2, 10, 0.8f, 0.8f, 4, 7, -1, true, 3, 1, 2, 0.0, 2, CRITERION:: GINI}, //forest with 10 trees, with bootstrap and column subsampling enabled, 3 bins - {4, 2, 10, 0.8f, 0.8f, 4, 7, -1, true, 
false, 3, SPLIT_ALGO::GLOBAL_QUANTILE, - 1, 2, 0.0, 1, + {4, 2, 10, 0.8f, 0.8f, 4, 7, -1, true, 3, 1, 2, 0.0, 1, CRITERION:: CRITERION_END}, //forest with 10 trees, with bootstrap and column subsampling enabled, 3 bins, different split algorithm - {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::ENTROPY}, - {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::ENTROPY}, - {4, 2, 10, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::ENTROPY}, - {4, 2, 10, 0.8f, 0.8f, 4, 7, -1, true, false, 3, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::ENTROPY}, - {4, 2, 10, 0.8f, 0.8f, 4, 7, -1, true, false, 3, SPLIT_ALGO::GLOBAL_QUANTILE, - 1, 2, 0.0, 2, CRITERION::ENTROPY}, - {50, 10, 10, 0.8f, 0.8f, 10, 7, -1, true, true, 3, - SPLIT_ALGO::GLOBAL_QUANTILE, 2, 2, 0.0, 2, CRITERION::ENTROPY}}; + {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, 4, 1, 2, 0.0, 2, CRITERION::ENTROPY}, + {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, 4, 1, 2, 0.0, 2, CRITERION::ENTROPY}, + {4, 2, 10, 1.0f, 1.0f, 4, 7, -1, false, 4, 1, 2, 0.0, 2, CRITERION::ENTROPY}, + {4, 2, 10, 0.8f, 0.8f, 4, 7, -1, true, 3, 1, 2, 0.0, 2, CRITERION::ENTROPY}, + {4, 2, 10, 0.8f, 0.8f, 4, 7, -1, true, 3, 1, 2, 0.0, 2, CRITERION::ENTROPY}, + {50, 10, 10, 0.8f, 0.8f, 10, 7, -1, true, 3, 1, 2, 0.0, 2, + CRITERION::ENTROPY}}; const std::vector> inputsd2_clf = { // Same as inputsf2_clf - {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::GINI}, - {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::GINI}, - {4, 2, 10, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::GINI}, - {4, 2, 10, 0.8f, 0.8f, 4, 7, -1, true, false, 3, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::GINI}, - {4, 2, 10, 0.8f, 0.8f, 4, 7, -1, true, false, 3, SPLIT_ALGO::GLOBAL_QUANTILE, - 1, 2, 0.0, 2, CRITERION::CRITERION_END}, - {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::ENTROPY}, - {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::ENTROPY}, - {4, 2, 10, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::ENTROPY}, - {4, 2, 10, 0.8f, 0.8f, 4, 7, -1, true, false, 3, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::ENTROPY}, - {4, 2, 10, 0.8f, 0.8f, 4, 7, -1, true, false, 3, SPLIT_ALGO::GLOBAL_QUANTILE, - 1, 2, 0.0, 2, CRITERION::ENTROPY}, - {50, 10, 10, 0.8f, 0.8f, 10, 7, -1, true, true, 3, - SPLIT_ALGO::GLOBAL_QUANTILE, 2, 2, 0.0, 2, CRITERION::ENTROPY}}; + {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, 4, 1, 2, 0.0, 2, CRITERION::GINI}, + {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, 4, 1, 2, 0.0, 2, CRITERION::GINI}, + {4, 2, 10, 1.0f, 1.0f, 4, 7, -1, false, 4, 1, 2, 0.0, 2, CRITERION::GINI}, + {4, 2, 10, 0.8f, 0.8f, 4, 7, -1, true, 3, 1, 2, 0.0, 2, CRITERION::GINI}, + {4, 2, 10, 0.8f, 0.8f, 4, 7, -1, true, 3, 1, 2, 0.0, 2, + CRITERION::CRITERION_END}, + {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, 4, 1, 2, 0.0, 2, CRITERION::ENTROPY}, + {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, 4, 1, 2, 0.0, 2, CRITERION::ENTROPY}, + {4, 2, 10, 1.0f, 1.0f, 4, 7, -1, false, 4, 1, 2, 0.0, 2, CRITERION::ENTROPY}, + {4, 2, 10, 0.8f, 0.8f, 4, 7, -1, true, 3, 1, 2, 0.0, 2, CRITERION::ENTROPY}, + {4, 2, 10, 0.8f, 0.8f, 4, 7, -1, true, 3, 1, 2, 0.0, 2, CRITERION::ENTROPY}, + {50, 10, 10, 0.8f, 0.8f, 10, 7, -1, true, 3, 1, 2, 0.0, 2, + CRITERION::ENTROPY}}; typedef RfClassifierTest 
RfClassifierTestF; TEST_P(RfClassifierTestF, Fit) { @@ -342,26 +320,21 @@ TEST_P(RfRegressorTestD, Fit) { } const std::vector> inputsf2_reg = { - {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::MSE}, - {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::MSE}, - {4, 2, 5, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, + {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, 4, 1, 2, 0.0, 2, CRITERION::MSE}, + {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, 4, 1, 2, 0.0, 2, CRITERION::MSE}, + {4, 2, 5, 1.0f, 1.0f, 4, 7, -1, false, 4, 1, 2, 0.0, 2, CRITERION:: CRITERION_END}, // CRITERION_END uses the default criterion (GINI for classification, MSE for regression) - {4, 2, 5, 1.0f, 1.0f, 4, 7, -1, true, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::CRITERION_END}}; + {4, 2, 5, 1.0f, 1.0f, 4, 7, -1, true, 4, 1, 2, 0.0, 2, + CRITERION::CRITERION_END}}; const std::vector> inputsd2_reg = { // Same as inputsf2_reg - {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::MSE}, - {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::MSE}, - {4, 2, 5, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::CRITERION_END}, - {4, 2, 5, 1.0f, 1.0f, 4, 7, -1, true, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::CRITERION_END}}; + {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, 4, 1, 2, 0.0, 2, CRITERION::MSE}, + {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, 4, 1, 2, 0.0, 2, CRITERION::MSE}, + {4, 2, 5, 1.0f, 1.0f, 4, 7, -1, false, 4, 1, 2, 0.0, 2, + CRITERION::CRITERION_END}, + {4, 2, 5, 1.0f, 1.0f, 4, 7, -1, true, 4, 1, 2, 0.0, 2, + CRITERION::CRITERION_END}}; INSTANTIATE_TEST_CASE_P(RfRegressorTests, RfRegressorTestF, ::testing::ValuesIn(inputsf2_reg)); diff --git a/cpp/test/sg/rf_treelite_test.cu b/cpp/test/sg/rf_treelite_test.cu index 3986eb695d..3113705120 100644 --- a/cpp/test/sg/rf_treelite_test.cu +++ b/cpp/test/sg/rf_treelite_test.cu @@ -48,9 +48,7 @@ struct RfInputs { int max_depth; int max_leaves; bool bootstrap; - bool bootstrap_features; int n_bins; - int split_algo; int min_samples_leaf; int min_samples_split; float min_impurity_decrease; @@ -185,10 +183,9 @@ class RfTreeliteTestCommon : public ::testing::TestWithParam> { rf_params = set_rf_params( params.max_depth, params.max_leaves, params.max_features, params.n_bins, - params.split_algo, params.min_samples_leaf, params.min_samples_split, - params.min_impurity_decrease, params.bootstrap_features, params.bootstrap, - params.n_trees, params.max_samples, 0, params.split_criterion, false, - params.n_streams, true, 128); + params.min_samples_leaf, params.min_samples_split, + params.min_impurity_decrease, params.bootstrap, params.n_trees, + params.max_samples, 0, params.split_criterion, params.n_streams, 128); handle.reset(new raft::handle_t(rf_params.n_streams)); @@ -432,33 +429,24 @@ class RfConcatTestReg : public RfTreeliteTestCommon { // //------------------------------------------------------------------------------------------------------------------------------------- const std::vector> inputsf2_clf = { - {4, 2, 1, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::GINI}, // single tree forest, bootstrap false, depth 8, 4 bins - {4, 2, 1, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, + {4, 2, 1, 1.0f, 1.0f, 4, 8, -1, false, 4, 2, 2, 0.0, 2, + CRITERION::GINI}, // single tree forest, 
bootstrap false, depth 8, 4 bins + {4, 2, 1, 1.0f, 1.0f, 4, 8, -1, false, 4, 2, 2, 0.0, 2, CRITERION::GINI}, // single tree forest, bootstrap false, depth of 8, 4 bins - {4, 2, 10, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, + {4, 2, 10, 1.0f, 1.0f, 4, 8, -1, false, 4, 2, 2, 0.0, 2, CRITERION:: GINI}, //forest with 10 trees, all trees should produce identical predictions (no bootstrapping or column subsampling) - {4, 2, 10, 0.8f, 0.8f, 4, 8, -1, true, false, 3, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, + {4, 2, 10, 0.8f, 0.8f, 4, 8, -1, true, 3, 2, 2, 0.0, 2, CRITERION:: GINI}, //forest with 10 trees, with bootstrap and column subsampling enabled, 3 bins - {4, 2, 10, 0.8f, 0.8f, 4, 8, -1, true, false, 3, SPLIT_ALGO::GLOBAL_QUANTILE, - 2, 2, 0.0, 2, + {4, 2, 10, 0.8f, 0.8f, 4, 8, -1, true, 3, 2, 2, 0.0, 2, CRITERION:: CRITERION_END}, //forest with 10 trees, with bootstrap and column subsampling enabled, 3 bins, different split algorithm - {4, 2, 1, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::ENTROPY}, - {4, 2, 1, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::ENTROPY}, - {4, 2, 10, 1.0f, 1.0f, 4, 8, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::ENTROPY}, - {4, 2, 10, 0.8f, 0.8f, 4, 8, -1, true, false, 3, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::ENTROPY}, - {4, 2, 10, 0.8f, 0.8f, 4, 8, -1, true, false, 3, SPLIT_ALGO::GLOBAL_QUANTILE, - 2, 2, 0.0, 2, CRITERION::ENTROPY}}; + {4, 2, 1, 1.0f, 1.0f, 4, 8, -1, false, 4, 2, 2, 0.0, 2, CRITERION::ENTROPY}, + {4, 2, 1, 1.0f, 1.0f, 4, 8, -1, false, 4, 2, 2, 0.0, 2, CRITERION::ENTROPY}, + {4, 2, 10, 1.0f, 1.0f, 4, 8, -1, false, 4, 2, 2, 0.0, 2, CRITERION::ENTROPY}, + {4, 2, 10, 0.8f, 0.8f, 4, 8, -1, true, 3, 2, 2, 0.0, 2, CRITERION::ENTROPY}, + {4, 2, 10, 0.8f, 0.8f, 4, 8, -1, true, 3, 2, 2, 0.0, 2, CRITERION::ENTROPY}}; typedef RfConcatTestClf RfClassifierConcatTestF; TEST_P(RfClassifierConcatTestF, Convert_Clf) { testClassifier(); } @@ -467,16 +455,13 @@ INSTANTIATE_TEST_CASE_P(RfBinaryClassifierConcatTests, RfClassifierConcatTestF, ::testing::ValuesIn(inputsf2_clf)); const std::vector> inputsf2_reg = { - {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::MSE}, - {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::MSE}, - {4, 2, 5, 1.0f, 1.0f, 4, 7, -1, false, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, + {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, 4, 2, 2, 0.0, 2, CRITERION::MSE}, + {4, 2, 1, 1.0f, 1.0f, 4, 7, -1, false, 4, 2, 2, 0.0, 2, CRITERION::MSE}, + {4, 2, 5, 1.0f, 1.0f, 4, 7, -1, false, 4, 2, 2, 0.0, 2, CRITERION:: CRITERION_END}, // CRITERION_END uses the default criterion (GINI for classification, MSE for regression) - {4, 2, 5, 1.0f, 1.0f, 4, 7, -1, true, false, 4, SPLIT_ALGO::HIST, 2, 2, 0.0, - 2, CRITERION::CRITERION_END}}; + {4, 2, 5, 1.0f, 1.0f, 4, 7, -1, true, 4, 2, 2, 0.0, 2, + CRITERION::CRITERION_END}}; typedef RfConcatTestReg RfRegressorConcatTestF; TEST_P(RfRegressorConcatTestF, Convert_Reg) { testRegressor(); } diff --git a/python/cuml/dask/ensemble/randomforestclassifier.py b/python/cuml/dask/ensemble/randomforestclassifier.py index 8aa4837eb2..d4240cc147 100755 --- a/python/cuml/dask/ensemble/randomforestclassifier.py +++ b/python/cuml/dask/ensemble/randomforestclassifier.py @@ -89,9 +89,6 @@ class RandomForestClassifier(BaseRandomForestModel, DelayedPredictionMixin, If set, each tree in the forest is built on a bootstrapped 
sample with replacement. If False, the whole dataset is used to build each tree. - bootstrap_features : boolean (default = False) - Control bootstrapping for features. - If features are drawn with or without replacement max_samples : float (default = 1.0) Ratio of dataset rows used while fitting each tree. max_depth : int (default = -1) diff --git a/python/cuml/dask/ensemble/randomforestregressor.py b/python/cuml/dask/ensemble/randomforestregressor.py index 0da3f08aca..3b21810fb4 100755 --- a/python/cuml/dask/ensemble/randomforestregressor.py +++ b/python/cuml/dask/ensemble/randomforestregressor.py @@ -81,9 +81,6 @@ class RandomForestRegressor(BaseRandomForestModel, DelayedPredictionMixin, If set, each tree in the forest is built on a bootstrapped sample with replacement. If False, the whole dataset is used to build each tree. - bootstrap_features : boolean (default = False) - Control bootstrapping for features. - If features are drawn with or without replacement max_samples : float (default = 1.0) Ratio of dataset rows used while fitting each tree. max_depth : int (default = -1) diff --git a/python/cuml/ensemble/randomforest_common.pyx b/python/cuml/ensemble/randomforest_common.pyx index 03b6c5cf00..1a8a70ba32 100644 --- a/python/cuml/ensemble/randomforest_common.pyx +++ b/python/cuml/ensemble/randomforest_common.pyx @@ -44,9 +44,9 @@ class BaseRandomForestModel(Base): 'split_algo', 'split_criterion', 'min_samples_leaf', 'min_samples_split', 'min_impurity_decrease', - 'bootstrap', 'bootstrap_features', + 'bootstrap', 'verbose', 'max_samples', - 'max_leaves', 'quantile_per_tree', + 'max_leaves', 'accuracy_metric', 'use_experimental_backend', 'max_batch_size'] @@ -56,14 +56,14 @@ class BaseRandomForestModel(Base): classes_ = CumlArrayDescriptor() def __init__(self, *, split_criterion, n_streams=8, n_estimators=100, - max_depth=16, handle=None, max_features='auto', n_bins=8, - split_algo=1, bootstrap=True, bootstrap_features=False, + max_depth=16, handle=None, max_features='auto', n_bins=128, + split_algo=1, bootstrap=True, verbose=False, min_samples_leaf=1, min_samples_split=2, max_samples=1.0, max_leaves=-1, accuracy_metric=None, dtype=None, output_type=None, min_weight_fraction_leaf=None, n_jobs=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, oob_score=None, random_state=None, - warm_start=None, class_weight=None, quantile_per_tree=False, + warm_start=None, class_weight=None, criterion=None, use_experimental_backend=True, max_batch_size=128): @@ -91,10 +91,12 @@ class BaseRandomForestModel(Base): "recommended. If n_streams is > 1, results may vary " "due to stream/thread timing differences, even when " "random_state is set") - if quantile_per_tree: - warnings.warn("The 'quantile_per_tree' parameter is deprecated " - "and will be removed in 21.06 release. 
Instead use " - "higher number of global quantile bins.") + if use_experimental_backend: + warnings.warn("The 'use_experimental_backend' parameter is " + "deprecated and will be removed in 21.10 release.") + if split_algo: + warnings.warn("The 'split_algo' parameter is deprecated " + "and will be removed in 21.10 release.") if handle is None: handle = Handle(n_streams) @@ -106,7 +108,6 @@ class BaseRandomForestModel(Base): if max_depth < 0: raise ValueError("Must specify max_depth >0 ") - self.split_algo = split_algo if (str(split_criterion) not in BaseRandomForestModel.criterion_dict.keys()): warnings.warn("The split criterion chosen was not present" @@ -120,7 +121,6 @@ class BaseRandomForestModel(Base): self.min_samples_leaf = min_samples_leaf self.min_samples_split = min_samples_split self.min_impurity_decrease = min_impurity_decrease - self.bootstrap_features = bootstrap_features self.max_samples = max_samples self.max_leaves = max_leaves self.n_estimators = n_estimators @@ -131,8 +131,6 @@ class BaseRandomForestModel(Base): self.n_cols = None self.dtype = dtype self.accuracy_metric = accuracy_metric - self.quantile_per_tree = quantile_per_tree - self.use_experimental_backend = use_experimental_backend self.max_batch_size = max_batch_size self.n_streams = handle.getNumInternalStreams() self.random_state = random_state diff --git a/python/cuml/ensemble/randomforest_shared.pxd b/python/cuml/ensemble/randomforest_shared.pxd index ce9ce05a76..9e3c23fb4f 100644 --- a/python/cuml/ensemble/randomforest_shared.pxd +++ b/python/cuml/ensemble/randomforest_shared.pxd @@ -96,17 +96,13 @@ cdef extern from "cuml/ensemble/randomforest.hpp" namespace "ML": int, int, int, - int, float, bool, - bool, int, float, uint64_t, CRITERION, - bool, int, - bool, int) except + cdef vector[unsigned char] save_model(ModelHandle) diff --git a/python/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/ensemble/randomforestclassifier.pyx index f30ae09c7f..b84cdb4690 100644 --- a/python/cuml/ensemble/randomforestclassifier.pyx +++ b/python/cuml/ensemble/randomforestclassifier.pyx @@ -197,24 +197,13 @@ class RandomForestClassifier(BaseRandomForestModel, 2 and 3 not valid for classification (default = 0) split_algo : int (default = 1) - The algorithm to determine how nodes are split in the tree. - Can be changed only for the old backend [deprecated]. - 0 for HIST and 1 for GLOBAL_QUANTILE. Default is GLOBAL_QUANTILE. - The default backend does not support HIST. - HIST currently uses a slower tree-building algorithm so - GLOBAL_QUANTILE is recommended for most cases. - + Deprecated and currently has no effect. .. deprecated:: 21.06 - Parameter 'split_algo' is deprecated and will be removed in - subsequent release. bootstrap : boolean (default = True) Control bootstrapping. If True, each tree in the forest is built on a bootstrapped sample with replacement. If False, the whole dataset is used to build each tree. - bootstrap_features : boolean (default = False) - Control bootstrapping for features. - If features are drawn with or without replacement max_samples : float (default = 1.0) Ratio of dataset rows used while fitting each tree. max_depth : int (default = 16) @@ -251,20 +240,9 @@ class RandomForestClassifier(BaseRandomForestModel, min_impurity_decrease : float (default = 0.0) Minimum decrease in impurity required for node to be split. - quantile_per_tree : boolean (default = False) - Whether quantile is computed for individual trees in RF. - Only relevant when `split_algo = GLOBAL_QUANTILE`. - - .. 
deprecated:: 0.19 - Parameter 'quantile_per_tree' is deprecated and will be removed in - subsequent release. use_experimental_backend : boolean (default = True) - If set to true and the following conditions are also met, a new - experimental backend for decision tree training will be used. The - new backend is available only if `split_algo = 1` (GLOBAL_QUANTILE) - and `quantile_per_tree = False` (No per tree quantile computation). - The new backend is now considered stable for both classification - and regression tasks and is significantly faster than the old backend. + Deprecated and currently has no effect. + .. deprecated:: 21.08 max_batch_size: int (default = 128) Maximum number of nodes that can be processed in a given batch. This is used only when 'use_experimental_backend' is true. Does not currently @@ -293,7 +271,7 @@ class RandomForestClassifier(BaseRandomForestModel, """ def __init__(self, *, split_criterion=0, handle=None, verbose=False, - output_type=None, n_bins=128, use_experimental_backend=True, + output_type=None, **kwargs): self.RF_type = CLASSIFICATION @@ -303,8 +281,6 @@ class RandomForestClassifier(BaseRandomForestModel, handle=handle, verbose=verbose, output_type=output_type, - n_bins=n_bins, - use_experimental_backend=use_experimental_backend, **kwargs) """ @@ -503,19 +479,15 @@ class RandomForestClassifier(BaseRandomForestModel, self.max_leaves, max_feature_val, self.n_bins, - self.split_algo, self.min_samples_leaf, self.min_samples_split, self.min_impurity_decrease, - self.bootstrap_features, self.bootstrap, self.n_estimators, self.max_samples, seed_val, self.split_criterion, - self.quantile_per_tree, self.n_streams, - self.use_experimental_backend, self.max_batch_size) if self.dtype == np.float32: diff --git a/python/cuml/ensemble/randomforestregressor.pyx b/python/cuml/ensemble/randomforestregressor.pyx index 7de9b8cf7a..794eadbcb4 100644 --- a/python/cuml/ensemble/randomforestregressor.pyx +++ b/python/cuml/ensemble/randomforestregressor.pyx @@ -183,9 +183,6 @@ class RandomForestRegressor(BaseRandomForestModel, If True, each tree in the forest is built on a bootstrapped sample with replacement. If False, the whole dataset is used to build each tree. - bootstrap_features : boolean (default = False) - Control bootstrapping for features. - If features are drawn with or without replacement max_samples : float (default = 1.0) Ratio of dataset rows used while fitting each tree. max_depth : int (default = 16) @@ -230,20 +227,9 @@ class RandomForestRegressor(BaseRandomForestModel, for median of abs error : 'median_ae' for mean of abs error : 'mean_ae' for mean square error : 'mse' - quantile_per_tree : boolean (default = False) - Whether quantile is computed for individual trees in RF. - Only relevant when `split_algo = GLOBAL_QUANTILE`. - - .. deprecated:: 0.19 - Parameter 'quantile_per_tree' is deprecated and will be removed in - subsequent release. use_experimental_backend : boolean (default = True) - If set to true and the following conditions are also met, a new - experimental backend for decision tree training will be used. The - new backend is available only if `split_algo = 1` (GLOBAL_QUANTILE) - and `quantile_per_tree = False` (No per tree quantile computation). - The new backend is now considered stable for both classification - and regression tasks and is significantly faster than the old backend. + Deprecated and currently has no effect. + .. 
deprecated:: 21.08 max_batch_size: int (default = 128) Maximum number of nodes that can be processed in a given batch. This is used only when 'use_experimental_backend' is true. @@ -464,19 +450,15 @@ class RandomForestRegressor(BaseRandomForestModel, self.max_leaves, max_feature_val, self.n_bins, - self.split_algo, self.min_samples_leaf, self.min_samples_split, self.min_impurity_decrease, - self.bootstrap_features, self.bootstrap, self.n_estimators, self.max_samples, seed_val, self.split_criterion, - self.quantile_per_tree, self.n_streams, - self.use_experimental_backend, self.max_batch_size) if self.dtype == np.float32: diff --git a/python/cuml/test/dask/test_random_forest.py b/python/cuml/test/dask/test_random_forest.py index 1d3850f241..71595a52ef 100644 --- a/python/cuml/test/dask/test_random_forest.py +++ b/python/cuml/test/dask/test_random_forest.py @@ -382,14 +382,14 @@ def test_rf_get_json(client, estimator_type, max_depth, n_estimators): X = X.astype(np.float32) if estimator_type == 'classification': cu_rf_mg = cuRFC_mg(max_features=1.0, max_samples=1.0, - n_bins=16, split_algo=0, split_criterion=0, + n_bins=16, split_criterion=0, min_samples_leaf=2, random_state=23707, n_streams=1, n_estimators=n_estimators, max_leaves=-1, max_depth=max_depth) y = y.astype(np.int32) elif estimator_type == 'regression': cu_rf_mg = cuRFR_mg(max_features=1.0, max_samples=1.0, - n_bins=16, split_algo=0, + n_bins=16, min_samples_leaf=2, random_state=23707, n_streams=1, n_estimators=n_estimators, max_leaves=-1, max_depth=max_depth) @@ -417,7 +417,7 @@ def predict_with_json_tree(tree, x): assert 'split_threshold' in tree assert 'yes' in tree assert 'no' in tree - if x[tree['split_feature']] <= tree['split_threshold']: + if x[tree['split_feature']] <= tree['split_threshold'] + 1e-5: return predict_with_json_tree(tree['children'][0], x) return predict_with_json_tree(tree['children'][1], x) @@ -470,7 +470,7 @@ def test_rf_instance_count(client, max_depth, n_estimators): n_classes=2) X = X.astype(np.float32) cu_rf_mg = cuRFC_mg(max_features=1.0, max_samples=1.0, - n_bins=16, split_algo=1, split_criterion=0, + n_bins=16, split_criterion=0, min_samples_leaf=2, random_state=23707, n_streams=1, n_estimators=n_estimators, max_leaves=-1, max_depth=max_depth) diff --git a/python/cuml/test/explainer/test_explainer_kernel_shap.py b/python/cuml/test/explainer/test_explainer_kernel_shap.py index 2c20d005fa..219801f384 100644 --- a/python/cuml/test/explainer/test_explainer_kernel_shap.py +++ b/python/cuml/test/explainer/test_explainer_kernel_shap.py @@ -376,10 +376,30 @@ def test_l1_regularization(exact_shap_regression_dataset, l1_type): ] housing_regression_result = np.array( - [[-0.7222878, 0.00888237, -0.07044561, -0.02764106, -0.01486777, - -0.19961227, -0.1367276, -0.11073875], - [-0.688218, 0.04260924, -0.12853414, 0.06109668, -0.01486243, - -0.0627693, -0.17290883, -0.02488524]], dtype=np.float32) + [ + [ + -0.73860609, + 0.00557072, + -0.05829297, + -0.01582018, + -0.01010366, + -0.23167623, + -0.470639, + -0.07584473, + ], + [ + -0.6410764, + 0.01369913, + -0.09492759, + 0.02654463, + -0.00911134, + -0.05953105, + -0.51266433, + -0.0853608, + ], + ], + dtype=np.float32, +) cuml_skl_class_dict = { cuml.LinearRegression: sklearn.linear_model.LinearRegression, diff --git a/python/cuml/test/test_random_forest.py b/python/cuml/test/test_random_forest.py index 07dc6629a5..43a42d8132 100644 --- a/python/cuml/test/test_random_forest.py +++ b/python/cuml/test/test_random_forest.py @@ -27,8 +27,7 @@ from cuml.ensemble 
import RandomForestClassifier as curfc from cuml.ensemble import RandomForestRegressor as curfr from cuml.metrics import r2_score -from cuml.test.utils import get_handle, unit_param, \ - quality_param, stress_param +from cuml.test.utils import get_handle, unit_param, quality_param, stress_param import cuml.common.logger as logger from sklearn.ensemble import RandomForestClassifier as skrfc @@ -45,163 +44,200 @@ @pytest.fixture( scope="session", params=[ - unit_param({'n_samples': 350, 'n_features': 20, 'n_informative': 10}), - quality_param({'n_samples': 5000, 'n_features': 200, - 'n_informative': 80}), - stress_param({'n_samples': 500000, 'n_features': 400, - 'n_informative': 180}) - ]) + unit_param({"n_samples": 350, "n_features": 20, "n_informative": 10}), + quality_param( + {"n_samples": 5000, "n_features": 200, "n_informative": 80} + ), + stress_param( + {"n_samples": 500000, "n_features": 400, "n_informative": 180} + ), + ], +) def small_clf(request): - X, y = make_classification(n_samples=request.param['n_samples'], - n_features=request.param['n_features'], - n_clusters_per_class=1, - n_informative=request.param['n_informative'], - random_state=123, n_classes=2) + X, y = make_classification( + n_samples=request.param["n_samples"], + n_features=request.param["n_features"], + n_clusters_per_class=1, + n_informative=request.param["n_informative"], + random_state=123, + n_classes=2, + ) return X, y @pytest.fixture( scope="session", params=[ - unit_param({'n_samples': 350, 'n_features': 30, 'n_informative': 15}), - quality_param({'n_samples': 5000, 'n_features': 200, - 'n_informative': 80}), - stress_param({'n_samples': 500000, 'n_features': 400, - 'n_informative': 180}) - ]) + unit_param({"n_samples": 350, "n_features": 30, "n_informative": 15}), + quality_param( + {"n_samples": 5000, "n_features": 200, "n_informative": 80} + ), + stress_param( + {"n_samples": 500000, "n_features": 400, "n_informative": 180} + ), + ], +) def mclass_clf(request): - X, y = make_classification(n_samples=request.param['n_samples'], - n_features=request.param['n_features'], - n_clusters_per_class=1, - n_informative=request.param['n_informative'], - random_state=123, n_classes=10) + X, y = make_classification( + n_samples=request.param["n_samples"], + n_features=request.param["n_features"], + n_clusters_per_class=1, + n_informative=request.param["n_informative"], + random_state=123, + n_classes=10, + ) return X, y @pytest.fixture( scope="session", params=[ - unit_param({'n_samples': 500, 'n_features': 20, 'n_informative': 10}), - quality_param({'n_samples': 5000, 'n_features': 200, - 'n_informative': 50}), - stress_param({'n_samples': 500000, 'n_features': 400, - 'n_informative': 100}) - ]) + unit_param({"n_samples": 500, "n_features": 20, "n_informative": 10}), + quality_param( + {"n_samples": 5000, "n_features": 200, "n_informative": 50} + ), + stress_param( + {"n_samples": 500000, "n_features": 400, "n_informative": 100} + ), + ], +) def large_clf(request): - X, y = make_classification(n_samples=request.param['n_samples'], - n_features=request.param['n_features'], - n_clusters_per_class=1, - n_informative=request.param['n_informative'], - random_state=123, n_classes=2) + X, y = make_classification( + n_samples=request.param["n_samples"], + n_features=request.param["n_features"], + n_clusters_per_class=1, + n_informative=request.param["n_informative"], + random_state=123, + n_classes=2, + ) return X, y @pytest.fixture( scope="session", params=[ - unit_param({'n_samples': 1500, 'n_features': 20, 
'n_informative': 10}), - quality_param({'n_samples': 12000, 'n_features': 200, - 'n_informative': 100}), - stress_param({'n_samples': 500000, 'n_features': 500, - 'n_informative': 350}) - ]) + unit_param({"n_samples": 1500, "n_features": 20, "n_informative": 10}), + quality_param( + {"n_samples": 12000, "n_features": 200, "n_informative": 100} + ), + stress_param( + {"n_samples": 500000, "n_features": 500, "n_informative": 350} + ), + ], +) def large_reg(request): - X, y = make_regression(n_samples=request.param['n_samples'], - n_features=request.param['n_features'], - n_informative=request.param['n_informative'], - random_state=123) + X, y = make_regression( + n_samples=request.param["n_samples"], + n_features=request.param["n_features"], + n_informative=request.param["n_informative"], + random_state=123, + ) return X, y special_reg_params = [ - unit_param({'mode': 'unit', 'n_samples': 500, - 'n_features': 20, 'n_informative': 10}), - quality_param({'mode': 'quality', 'n_samples': 500, - 'n_features': 20, 'n_informative': 10}), - quality_param({'mode': 'quality', 'n_features': 200, - 'n_informative': 50}), - stress_param({'mode': 'stress', 'n_samples': 500, - 'n_features': 20, 'n_informative': 10}), - stress_param({'mode': 'stress', 'n_features': 200, - 'n_informative': 50}), - stress_param({'mode': 'stress', 'n_samples': 1000, - 'n_features': 400, 'n_informative': 100}) - ] - - -@pytest.fixture( - scope="session", - params=special_reg_params) + unit_param( + { + "mode": "unit", + "n_samples": 500, + "n_features": 20, + "n_informative": 10, + } + ), + quality_param( + { + "mode": "quality", + "n_samples": 500, + "n_features": 20, + "n_informative": 10, + } + ), + quality_param({"mode": "quality", "n_features": 200, "n_informative": 50}), + stress_param( + { + "mode": "stress", + "n_samples": 500, + "n_features": 20, + "n_informative": 10, + } + ), + stress_param({"mode": "stress", "n_features": 200, "n_informative": 50}), + stress_param( + { + "mode": "stress", + "n_samples": 1000, + "n_features": 400, + "n_informative": 100, + } + ), +] + + +@pytest.fixture(scope="session", params=special_reg_params) def special_reg(request): - if request.param['mode'] == 'quality': + if request.param["mode"] == "quality": X, y = fetch_california_housing(return_X_y=True) else: - X, y = make_regression(n_samples=request.param['n_samples'], - n_features=request.param['n_features'], - n_informative=request.param['n_informative'], - random_state=123) + X, y = make_regression( + n_samples=request.param["n_samples"], + n_features=request.param["n_features"], + n_informative=request.param["n_informative"], + random_state=123, + ) return X, y -@pytest.mark.parametrize('max_samples', [unit_param(1.0), quality_param(0.90), - stress_param(0.95)]) -@pytest.mark.parametrize('datatype', [np.float32]) -@pytest.mark.parametrize('split_algo', [0, 1]) -@pytest.mark.parametrize('max_features', [1.0, 'auto', 'log2', 'sqrt']) -@pytest.mark.parametrize('use_experimental_backend', [True, False]) -def test_rf_classification(small_clf, datatype, split_algo, - max_samples, max_features, - use_experimental_backend): +@pytest.mark.parametrize( + "max_samples", [unit_param(1.0), quality_param(0.90), stress_param(0.95)] +) +@pytest.mark.parametrize("datatype", [np.float32]) +@pytest.mark.parametrize("max_features", [1.0, "auto", "log2", "sqrt"]) +def test_rf_classification(small_clf, datatype, max_samples, max_features): use_handle = True X, y = small_clf X = X.astype(datatype) y = y.astype(np.int32) - X_train, X_test, y_train, 
y_test = train_test_split(X, y, train_size=0.8, - random_state=0) + X_train, X_test, y_train, y_test = train_test_split( + X, y, train_size=0.8, random_state=0 + ) # Create a handle for the cuml model handle, stream = get_handle(use_handle, n_streams=1) # Initialize, fit and predict using cuML's # random forest classification model - cuml_model = curfc(max_features=max_features, max_samples=max_samples, - n_bins=16, split_algo=split_algo, split_criterion=0, - min_samples_leaf=2, random_state=123, n_streams=1, - n_estimators=40, handle=handle, max_leaves=-1, - max_depth=16, - use_experimental_backend=use_experimental_backend) + cuml_model = curfc( + max_features=max_features, + max_samples=max_samples, + n_bins=16, + split_criterion=0, + min_samples_leaf=2, + random_state=123, + n_streams=1, + n_estimators=40, + handle=handle, + max_leaves=-1, + max_depth=16, + ) f = io.StringIO() with redirect_stdout(f): cuml_model.fit(X_train, y_train) - captured_stdout = f.getvalue() - - is_fallback_used = False - if split_algo != 1 and use_experimental_backend: - assert ('Experimental backend does not yet support histogram ' + - 'split algorithm' in captured_stdout) - is_fallback_used = True - if is_fallback_used: - assert ('Not using the experimental backend due to above ' + - 'mentioned reason(s)' in captured_stdout) - if not use_experimental_backend: - assert('The old backend is deprecated and will be removed in 21.08 release.' # noqa: E501 - in captured_stdout) - assert('Using old backend for growing trees' - in captured_stdout) - - fil_preds = cuml_model.predict(X_test, - predict_model="GPU", - threshold=0.5, - algo='auto') + + fil_preds = cuml_model.predict( + X_test, predict_model="GPU", threshold=0.5, algo="auto" + ) cu_preds = cuml_model.predict(X_test, predict_model="CPU") fil_preds = np.reshape(fil_preds, np.shape(cu_preds)) cuml_acc = accuracy_score(y_test, cu_preds) fil_acc = accuracy_score(y_test, fil_preds) if X.shape[0] < 500000: - sk_model = skrfc(n_estimators=40, - max_depth=16, - min_samples_split=2, max_features=max_features, - random_state=10) + sk_model = skrfc( + n_estimators=40, + max_depth=16, + min_samples_split=2, + max_features=max_features, + random_state=10, + ) sk_model.fit(X_train, y_train) sk_preds = sk_model.predict(X_test) sk_acc = accuracy_score(y_test, sk_preds) @@ -209,63 +245,54 @@ def test_rf_classification(small_clf, datatype, split_algo, assert fil_acc >= (cuml_acc - 0.07) # to be changed to 0.02. 
see issue #3910: https://github.com/rapidsai/cuml/issues/3910 # noqa -@pytest.mark.parametrize('max_samples', [unit_param(1.0), quality_param(0.90), - stress_param(0.95)]) -@pytest.mark.parametrize('datatype', [np.float32]) @pytest.mark.parametrize( - 'split_algo,max_features,use_experimental_backend,n_bins', - [(0, 1.0, False, 16), - (1, 1.0, False, 11), - (0, 'auto', False, 128), - (1, 'log2', False, 100), - (1, 'sqrt', False, 100), - (1, 1.0, True, 17), - (1, 1.0, True, 32), - (0, 1.0, True, 16), - (1, 1.0, True, 11), - (0, 'auto', True, 128), - (1, 1.0, True, 100), - (1, 'log2', True, 100), - (1, 'sqrt', True, 100), - ]) -def test_rf_regression(special_reg, datatype, split_algo, max_features, - max_samples, use_experimental_backend, n_bins): + "max_samples", [unit_param(1.0), quality_param(0.90), stress_param(0.95)] +) +@pytest.mark.parametrize("datatype", [np.float32]) +@pytest.mark.parametrize( + "max_features,n_bins", + [ + (1.0, 16), + (1.0, 11), + ("auto", 128), + ("log2", 100), + ("sqrt", 100), + (1.0, 17), + (1.0, 32), + ], +) +def test_rf_regression( + special_reg, datatype, max_features, max_samples, n_bins +): + use_handle = True X, y = special_reg X = X.astype(datatype) y = y.astype(datatype) - X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, - random_state=0) + X_train, X_test, y_train, y_test = train_test_split( + X, y, train_size=0.8, random_state=0 + ) # Create a handle for the cuml model handle, stream = get_handle(use_handle, n_streams=1) # Initialize and fit using cuML's random forest regression model - cuml_model = curfr(max_features=max_features, max_samples=max_samples, - n_bins=n_bins, split_algo=split_algo, split_criterion=2, - min_samples_leaf=2, random_state=123, n_streams=1, - n_estimators=50, handle=handle, max_leaves=-1, - max_depth=16, accuracy_metric='mse', - use_experimental_backend=use_experimental_backend) - f = io.StringIO() - with redirect_stdout(f): - cuml_model.fit(X_train, y_train) - captured_stdout = f.getvalue() - - is_fallback_used = False - if split_algo != 1 and use_experimental_backend: - assert ('Experimental backend does not yet support histogram ' + - 'split algorithm' in captured_stdout) - is_fallback_used = True - if is_fallback_used: - assert ('Not using the experimental backend due to above ' + - 'mentioned reason(s)' in captured_stdout) - if not use_experimental_backend: - assert('The old backend is deprecated and will be removed in 21.08 release.' 
# noqa: E501 - in captured_stdout) - assert('Using old backend for growing trees' - in captured_stdout) + cuml_model = curfr( + max_features=max_features, + max_samples=max_samples, + n_bins=n_bins, + split_criterion=2, + min_samples_leaf=2, + random_state=123, + n_streams=1, + n_estimators=50, + handle=handle, + max_leaves=-1, + max_depth=16, + accuracy_metric="mse", + ) + cuml_model.fit(X_train, y_train) # predict using FIL fil_preds = cuml_model.predict(X_test, predict_model="GPU") cu_preds = cuml_model.predict(X_test, predict_model="CPU") @@ -276,9 +303,13 @@ def test_rf_regression(special_reg, datatype, split_algo, max_features, # Initialize, fit and predict using # sklearn's random forest regression model if X.shape[0] < 1000: # mode != "stress" - sk_model = skrfr(n_estimators=50, max_depth=16, - min_samples_split=2, max_features=max_features, - random_state=10) + sk_model = skrfr( + n_estimators=50, + max_depth=16, + min_samples_split=2, + max_features=max_features, + random_state=10, + ) sk_model.fit(X_train, y_train) sk_preds = sk_model.predict(X_test) sk_r2 = r2_score(y_test, sk_preds, convert_dtype=datatype) @@ -286,14 +317,15 @@ def test_rf_regression(special_reg, datatype, split_algo, max_features, assert fil_r2 >= (cu_r2 - 0.02) -@pytest.mark.parametrize('datatype', [np.float32]) +@pytest.mark.parametrize("datatype", [np.float32]) def test_rf_classification_seed(small_clf, datatype): X, y = small_clf X = X.astype(datatype) y = y.astype(np.int32) - X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, - random_state=0) + X_train, X_test, y_train, y_test = train_test_split( + X, y, train_size=0.8, random_state=0 + ) for i in range(8): seed = random.randint(100, 1e5) @@ -303,10 +335,8 @@ def test_rf_classification_seed(small_clf, datatype): cu_class.fit(X_train, y_train) # predict using FIL - fil_preds_orig = cu_class.predict(X_test, - predict_model="GPU") - cu_preds_orig = cu_class.predict(X_test, - predict_model="CPU") + fil_preds_orig = cu_class.predict(X_test, predict_model="GPU") + cu_preds_orig = cu_class.predict(X_test, predict_model="CPU") cu_acc_orig = accuracy_score(y_test, cu_preds_orig) fil_preds_orig = np.reshape(fil_preds_orig, np.shape(cu_preds_orig)) @@ -318,8 +348,7 @@ def test_rf_classification_seed(small_clf, datatype): cu_class2.fit(X_train, y_train) # predict using FIL - fil_preds_rerun = cu_class2.predict(X_test, - predict_model="GPU") + fil_preds_rerun = cu_class2.predict(X_test, predict_model="GPU") cu_preds_rerun = cu_class2.predict(X_test, predict_model="CPU") cu_acc_rerun = accuracy_score(y_test, cu_preds_rerun) fil_preds_rerun = np.reshape(fil_preds_rerun, np.shape(cu_preds_rerun)) @@ -332,16 +361,18 @@ def test_rf_classification_seed(small_clf, datatype): assert (cu_preds_orig == cu_preds_rerun).all() -@pytest.mark.parametrize('datatype', [(np.float64, np.float32), - (np.float32, np.float64)]) -@pytest.mark.parametrize('convert_dtype', [True, False]) +@pytest.mark.parametrize( + "datatype", [(np.float64, np.float32), (np.float32, np.float64)] +) +@pytest.mark.parametrize("convert_dtype", [True, False]) def test_rf_classification_float64(small_clf, datatype, convert_dtype): X, y = small_clf X = X.astype(datatype[0]) y = y.astype(np.int32) - X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, - random_state=0) + X_train, X_test, y_train, y_test = train_test_split( + X, y, train_size=0.8, random_state=0 + ) X_test = X_test.astype(datatype[1]) # Initialize, fit and predict using cuML's @@ -362,25 +393,29 @@ def 
test_rf_classification_float64(small_clf, datatype, convert_dtype): # predict using cuML's GPU based prediction if datatype[0] == np.float32 and convert_dtype: - fil_preds = cuml_model.predict(X_test, predict_model="GPU", - convert_dtype=convert_dtype) + fil_preds = cuml_model.predict( + X_test, predict_model="GPU", convert_dtype=convert_dtype + ) fil_preds = np.reshape(fil_preds, np.shape(cu_preds)) fil_acc = accuracy_score(y_test, fil_preds) assert fil_acc >= (cu_acc - 0.07) # to be changed to 0.02. see issue #3910: https://github.com/rapidsai/cuml/issues/3910 # noqa else: with pytest.raises(TypeError): - fil_preds = cuml_model.predict(X_test, predict_model="GPU", - convert_dtype=convert_dtype) + fil_preds = cuml_model.predict( + X_test, predict_model="GPU", convert_dtype=convert_dtype + ) -@pytest.mark.parametrize('datatype', [(np.float64, np.float32), - (np.float32, np.float64)]) +@pytest.mark.parametrize( + "datatype", [(np.float64, np.float32), (np.float32, np.float64)] +) def test_rf_regression_float64(large_reg, datatype): X, y = large_reg - X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, - random_state=0) + X_train, X_test, y_train, y_test = train_test_split( + X, y, train_size=0.8, random_state=0 + ) X_train = X_train.astype(datatype[0]) y_train = y_train.astype(datatype[0]) X_test = X_test.astype(datatype[1]) @@ -404,16 +439,18 @@ def test_rf_regression_float64(large_reg, datatype): # predict using cuML's GPU based prediction if datatype[0] == np.float32: - fil_preds = cuml_model.predict(X_test, predict_model="GPU", - convert_dtype=True) + fil_preds = cuml_model.predict( + X_test, predict_model="GPU", convert_dtype=True + ) fil_preds = np.reshape(fil_preds, np.shape(cu_preds)) fil_r2 = r2_score(y_test, fil_preds, convert_dtype=datatype[0]) assert fil_r2 >= (cu_r2 - 0.02) # because datatype[0] != np.float32 or datatype[0] != datatype[1] with pytest.raises(TypeError): - fil_preds = cuml_model.predict(X_test, predict_model="GPU", - convert_dtype=False) + fil_preds = cuml_model.predict( + X_test, predict_model="GPU", convert_dtype=False + ) def check_predict_proba(test_proba, baseline_proba, y_test, rel_err): @@ -427,41 +464,54 @@ def check_predict_proba(test_proba, baseline_proba, y_test, rel_err): assert test_mse <= baseline_mse * (1.0 + rel_err) -def rf_classification(datatype, array_type, max_features, max_samples, - fixture): +def rf_classification( + datatype, array_type, max_features, max_samples, fixture +): X, y = fixture X = X.astype(datatype[0]) y = y.astype(np.int32) - X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, - random_state=0) + X_train, X_test, y_train, y_test = train_test_split( + X, y, train_size=0.8, random_state=0 + ) X_test = X_test.astype(datatype[1]) handle, stream = get_handle(True, n_streams=1) # Initialize, fit and predict using cuML's # random forest classification model - cuml_model = curfc(max_features=max_features, max_samples=max_samples, - n_bins=16, split_criterion=0, - min_samples_leaf=2, random_state=123, - n_estimators=40, handle=handle, max_leaves=-1, - max_depth=16) - if array_type == 'dataframe': + cuml_model = curfc( + max_features=max_features, + max_samples=max_samples, + n_bins=16, + split_criterion=0, + min_samples_leaf=2, + random_state=123, + n_estimators=40, + handle=handle, + max_leaves=-1, + max_depth=16, + ) + if array_type == "dataframe": X_train_df = cudf.DataFrame(X_train) y_train_df = cudf.Series(y_train) X_test_df = cudf.DataFrame(X_test) cuml_model.fit(X_train_df, 
y_train_df) - cu_proba_gpu = np.array(cuml_model.predict_proba(X_test_df) - .as_gpu_matrix()) - cu_preds_cpu = cuml_model.predict(X_test_df, - predict_model="CPU").to_array() - cu_preds_gpu = cuml_model.predict(X_test_df, - predict_model="GPU").to_array() + cu_proba_gpu = np.array( + cuml_model.predict_proba(X_test_df).as_gpu_matrix() + ) + cu_preds_cpu = cuml_model.predict( + X_test_df, predict_model="CPU" + ).to_array() + cu_preds_gpu = cuml_model.predict( + X_test_df, predict_model="GPU" + ).to_array() else: cuml_model.fit(X_train, y_train) cu_proba_gpu = cuml_model.predict_proba(X_test) cu_preds_cpu = cuml_model.predict(X_test, predict_model="CPU") cu_preds_gpu = cuml_model.predict(X_test, predict_model="GPU") - np.testing.assert_array_equal(cu_preds_gpu, - np.argmax(cu_proba_gpu, axis=1)) + np.testing.assert_array_equal( + cu_preds_gpu, np.argmax(cu_proba_gpu, axis=1) + ) cu_acc_cpu = accuracy_score(y_test, cu_preds_cpu) cu_acc_gpu = accuracy_score(y_test, cu_preds_gpu) @@ -470,10 +520,13 @@ def rf_classification(datatype, array_type, max_features, max_samples, # sklearn random forest classification model # initialization, fit and predict if y.size < 500000: - sk_model = skrfc(n_estimators=40, - max_depth=16, - min_samples_split=2, max_features=max_features, - random_state=10) + sk_model = skrfc( + n_estimators=40, + max_depth=16, + min_samples_split=2, + max_features=max_features, + random_state=10, + ) sk_model.fit(X_train, y_train) sk_preds = sk_model.predict(X_test) sk_acc = accuracy_score(y_test, sk_preds) @@ -485,67 +538,84 @@ def rf_classification(datatype, array_type, max_features, max_samples, check_predict_proba(cu_proba_gpu, sk_proba, y_test, 0.1) -@pytest.mark.parametrize('datatype', [(np.float32, np.float32)]) -@pytest.mark.parametrize('array_type', ['dataframe', 'numpy']) +@pytest.mark.parametrize("datatype", [(np.float32, np.float32)]) +@pytest.mark.parametrize("array_type", ["dataframe", "numpy"]) def test_rf_classification_multi_class(mclass_clf, datatype, array_type): rf_classification(datatype, array_type, 1.0, 1.0, mclass_clf) -@pytest.mark.parametrize('datatype', [(np.float32, np.float32)]) -@pytest.mark.parametrize('max_samples', [unit_param(1.0), - stress_param(0.95)]) -@pytest.mark.parametrize('max_features', [1.0, 'auto', 'log2', 'sqrt']) -def test_rf_classification_proba(small_clf, datatype, - max_samples, max_features): - rf_classification(datatype, 'numpy', max_features, max_samples, - small_clf) +@pytest.mark.parametrize("datatype", [(np.float32, np.float32)]) +@pytest.mark.parametrize("max_samples", [unit_param(1.0), stress_param(0.95)]) +@pytest.mark.parametrize("max_features", [1.0, "auto", "log2", "sqrt"]) +def test_rf_classification_proba( + small_clf, datatype, max_samples, max_features +): + rf_classification(datatype, "numpy", max_features, max_samples, small_clf) -@pytest.mark.parametrize('datatype', [np.float32]) -@pytest.mark.parametrize('fil_sparse_format', ['not_supported', True, - 'auto', False]) -@pytest.mark.parametrize('algo', ['auto', 'naive', 'tree_reorg', - 'batch_tree_reorg']) -def test_rf_classification_sparse(small_clf, datatype, - fil_sparse_format, algo): +@pytest.mark.parametrize("datatype", [np.float32]) +@pytest.mark.parametrize( + "fil_sparse_format", ["not_supported", True, "auto", False] +) +@pytest.mark.parametrize( + "algo", ["auto", "naive", "tree_reorg", "batch_tree_reorg"] +) +def test_rf_classification_sparse( + small_clf, datatype, fil_sparse_format, algo +): use_handle = True num_treees = 50 X, y = small_clf X = 
X.astype(datatype) y = y.astype(np.int32) - X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, - random_state=0) + X_train, X_test, y_train, y_test = train_test_split( + X, y, train_size=0.8, random_state=0 + ) # Create a handle for the cuml model handle, stream = get_handle(use_handle, n_streams=1) # Initialize, fit and predict using cuML's # random forest classification model - cuml_model = curfc(n_bins=16, split_criterion=0, - min_samples_leaf=2, random_state=123, n_streams=1, - n_estimators=num_treees, handle=handle, max_leaves=-1, - max_depth=40) + cuml_model = curfc( + n_bins=16, + split_criterion=0, + min_samples_leaf=2, + random_state=123, + n_streams=1, + n_estimators=num_treees, + handle=handle, + max_leaves=-1, + max_depth=40, + ) cuml_model.fit(X_train, y_train) - if ((not fil_sparse_format or algo == 'tree_reorg' or - algo == 'batch_tree_reorg') or - fil_sparse_format == 'not_supported'): + if ( + not fil_sparse_format + or algo == "tree_reorg" + or algo == "batch_tree_reorg" + ) or fil_sparse_format == "not_supported": with pytest.raises(ValueError): - fil_preds = cuml_model.predict(X_test, - predict_model="GPU", - threshold=0.5, - fil_sparse_format=fil_sparse_format, - algo=algo) + fil_preds = cuml_model.predict( + X_test, + predict_model="GPU", + threshold=0.5, + fil_sparse_format=fil_sparse_format, + algo=algo, + ) else: - fil_preds = cuml_model.predict(X_test, - predict_model="GPU", - threshold=0.5, - fil_sparse_format=fil_sparse_format, - algo=algo) + fil_preds = cuml_model.predict( + X_test, + predict_model="GPU", + threshold=0.5, + fil_sparse_format=fil_sparse_format, + algo=algo, + ) fil_preds = np.reshape(fil_preds, np.shape(y_test)) fil_acc = accuracy_score(y_test, fil_preds) - np.testing.assert_almost_equal(fil_acc, - cuml_model.score(X_test, y_test)) + np.testing.assert_almost_equal( + fil_acc, cuml_model.score(X_test, y_test) + ) fil_model = cuml_model.convert_to_fil_model() @@ -559,21 +629,25 @@ def test_rf_classification_sparse(small_clf, datatype, assert X.shape[1] == tl_model.num_features if X.shape[0] < 500000: - sk_model = skrfc(n_estimators=50, - max_depth=40, - min_samples_split=2, - random_state=10) + sk_model = skrfc( + n_estimators=50, + max_depth=40, + min_samples_split=2, + random_state=10, + ) sk_model.fit(X_train, y_train) sk_preds = sk_model.predict(X_test) sk_acc = accuracy_score(y_test, sk_preds) assert fil_acc >= (sk_acc - 0.07) -@pytest.mark.parametrize('datatype', [np.float32]) -@pytest.mark.parametrize('fil_sparse_format', ['not_supported', True, - 'auto', False]) -@pytest.mark.parametrize('algo', ['auto', 'naive', 'tree_reorg', - 'batch_tree_reorg']) +@pytest.mark.parametrize("datatype", [np.float32]) +@pytest.mark.parametrize( + "fil_sparse_format", ["not_supported", True, "auto", False] +) +@pytest.mark.parametrize( + "algo", ["auto", "naive", "tree_reorg", "batch_tree_reorg"] +) def test_rf_regression_sparse(special_reg, datatype, fil_sparse_format, algo): use_handle = True num_treees = 50 @@ -581,31 +655,48 @@ def test_rf_regression_sparse(special_reg, datatype, fil_sparse_format, algo): X, y = special_reg X = X.astype(datatype) y = y.astype(datatype) - X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, - random_state=0) + X_train, X_test, y_train, y_test = train_test_split( + X, y, train_size=0.8, random_state=0 + ) # Create a handle for the cuml model handle, stream = get_handle(use_handle, n_streams=1) # Initialize and fit using cuML's random forest regression model - cuml_model = 
curfr(n_bins=16, split_criterion=2, - min_samples_leaf=2, random_state=123, n_streams=1, - n_estimators=num_treees, handle=handle, max_leaves=-1, - max_depth=40, accuracy_metric='mse') + cuml_model = curfr( + n_bins=16, + split_criterion=2, + min_samples_leaf=2, + random_state=123, + n_streams=1, + n_estimators=num_treees, + handle=handle, + max_leaves=-1, + max_depth=40, + accuracy_metric="mse", + ) cuml_model.fit(X_train, y_train) # predict using FIL - if ((not fil_sparse_format or algo == 'tree_reorg' or - algo == 'batch_tree_reorg') or - fil_sparse_format == 'not_supported'): + if ( + not fil_sparse_format + or algo == "tree_reorg" + or algo == "batch_tree_reorg" + ) or fil_sparse_format == "not_supported": with pytest.raises(ValueError): - fil_preds = cuml_model.predict(X_test, predict_model="GPU", - fil_sparse_format=fil_sparse_format, - algo=algo) + fil_preds = cuml_model.predict( + X_test, + predict_model="GPU", + fil_sparse_format=fil_sparse_format, + algo=algo, + ) else: - fil_preds = cuml_model.predict(X_test, predict_model="GPU", - fil_sparse_format=fil_sparse_format, - algo=algo) + fil_preds = cuml_model.predict( + X_test, + predict_model="GPU", + fil_sparse_format=fil_sparse_format, + algo=algo, + ) fil_preds = np.reshape(fil_preds, np.shape(y_test)) fil_r2 = r2_score(y_test, fil_preds, convert_dtype=datatype) @@ -614,8 +705,9 @@ def test_rf_regression_sparse(special_reg, datatype, fil_sparse_format, algo): with cuml.using_output_type("numpy"): fil_model_preds = fil_model.predict(X_test) fil_model_preds = np.reshape(fil_model_preds, np.shape(y_test)) - fil_model_r2 = r2_score(y_test, fil_model_preds, - convert_dtype=datatype) + fil_model_r2 = r2_score( + y_test, fil_model_preds, convert_dtype=datatype + ) assert fil_r2 == fil_model_r2 tl_model = cuml_model.convert_to_treelite_model() @@ -625,20 +717,24 @@ def test_rf_regression_sparse(special_reg, datatype, fil_sparse_format, algo): # Initialize, fit and predict using # sklearn's random forest regression model if X.shape[0] < 1000: # mode != "stress": - sk_model = skrfr(n_estimators=50, max_depth=40, - min_samples_split=2, - random_state=10) + sk_model = skrfr( + n_estimators=50, + max_depth=40, + min_samples_split=2, + random_state=10, + ) sk_model.fit(X_train, y_train) sk_preds = sk_model.predict(X_test) sk_r2 = r2_score(y_test, sk_preds, convert_dtype=datatype) assert fil_r2 >= (sk_r2 - 0.07) -@pytest.mark.xfail(reason='Need rapidsai/rmm#415 to detect memleak robustly') +@pytest.mark.xfail(reason="Need rapidsai/rmm#415 to detect memleak robustly") @pytest.mark.memleak -@pytest.mark.parametrize('fil_sparse_format', [True, False, 'auto']) -@pytest.mark.parametrize('n_iter', [unit_param(5), quality_param(30), - stress_param(80)]) +@pytest.mark.parametrize("fil_sparse_format", [True, False, "auto"]) +@pytest.mark.parametrize( + "n_iter", [unit_param(5), quality_param(30), stress_param(80)] +) def test_rf_memory_leakage(small_clf, fil_sparse_format, n_iter): datatype = np.float32 use_handle = True @@ -646,8 +742,9 @@ def test_rf_memory_leakage(small_clf, fil_sparse_format, n_iter): X, y = small_clf X = X.astype(datatype) y = y.astype(np.int32) - X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, - random_state=0) + X_train, X_test, y_train, y_test = train_test_split( + X, y, train_size=0.8, random_state=0 + ) # Create a handle for the cuml model handle, stream = get_handle(use_handle, n_streams=1) @@ -668,8 +765,11 @@ def test_for_memory_leak(): assert delta_mem == 0 for i in range(2): - 
cuml_mods.predict(X_test, predict_model="GPU", - fil_sparse_format=fil_sparse_format) + cuml_mods.predict( + X_test, + predict_model="GPU", + fil_sparse_format=fil_sparse_format, + ) handle.sync() # just to be sure # Calculate the memory free after predicting the cuML model delta_mem = free_mem - cuda.current_context().get_memory_info()[0] @@ -679,39 +779,40 @@ def test_for_memory_leak(): test_for_memory_leak() -@pytest.mark.parametrize('max_features', [1.0, 'auto', 'log2', 'sqrt']) -@pytest.mark.parametrize('max_depth', [10, 13, 16]) -@pytest.mark.parametrize('n_estimators', [10, 20, 100]) -@pytest.mark.parametrize('n_bins', [8, 9, 10]) -def test_create_classification_model(max_features, - max_depth, n_estimators, n_bins): +@pytest.mark.parametrize("max_features", [1.0, "auto", "log2", "sqrt"]) +@pytest.mark.parametrize("max_depth", [10, 13, 16]) +@pytest.mark.parametrize("n_estimators", [10, 20, 100]) +@pytest.mark.parametrize("n_bins", [8, 9, 10]) +def test_create_classification_model( + max_features, max_depth, n_estimators, n_bins +): # random forest classification model - cuml_model = curfc(max_features=max_features, - n_bins=n_bins, - n_estimators=n_estimators, - max_depth=max_depth) + cuml_model = curfc( + max_features=max_features, + n_bins=n_bins, + n_estimators=n_estimators, + max_depth=max_depth, + ) params = cuml_model.get_params() cuml_model2 = curfc() cuml_model2.set_params(**params) verfiy_params = cuml_model2.get_params() - assert params['max_features'] == verfiy_params['max_features'] - assert params['max_depth'] == verfiy_params['max_depth'] - assert params['n_estimators'] == verfiy_params['n_estimators'] - assert params['n_bins'] == verfiy_params['n_bins'] + assert params["max_features"] == verfiy_params["max_features"] + assert params["max_depth"] == verfiy_params["max_depth"] + assert params["n_estimators"] == verfiy_params["n_estimators"] + assert params["n_bins"] == verfiy_params["n_bins"] -@pytest.mark.parametrize('n_estimators', [10, 20, 100]) -@pytest.mark.parametrize('n_bins', [8, 9, 10]) +@pytest.mark.parametrize("n_estimators", [10, 20, 100]) +@pytest.mark.parametrize("n_bins", [8, 9, 10]) def test_multiple_fits_classification(large_clf, n_estimators, n_bins): datatype = np.float32 X, y = large_clf X = X.astype(datatype) y = y.astype(np.int32) - cuml_model = curfc(n_bins=n_bins, - n_estimators=n_estimators, - max_depth=10) + cuml_model = curfc(n_bins=n_bins, n_estimators=n_estimators, max_depth=10) # Calling multiple fits cuml_model.fit(X, y) @@ -720,28 +821,35 @@ def test_multiple_fits_classification(large_clf, n_estimators, n_bins): # Check if params are still intact params = cuml_model.get_params() - assert params['n_estimators'] == n_estimators - assert params['n_bins'] == n_bins + assert params["n_estimators"] == n_estimators + assert params["n_bins"] == n_bins -@pytest.mark.parametrize('column_info', [unit_param([100, 50]), - quality_param([200, 100]), - stress_param([500, 350])]) -@pytest.mark.parametrize('nrows', [unit_param(500), quality_param(5000), - stress_param(500000)]) -@pytest.mark.parametrize('n_estimators', [10, 20, 100]) -@pytest.mark.parametrize('n_bins', [8, 9, 10]) +@pytest.mark.parametrize( + "column_info", + [ + unit_param([100, 50]), + quality_param([200, 100]), + stress_param([500, 350]), + ], +) +@pytest.mark.parametrize( + "nrows", [unit_param(500), quality_param(5000), stress_param(500000)] +) +@pytest.mark.parametrize("n_estimators", [10, 20, 100]) +@pytest.mark.parametrize("n_bins", [8, 9, 10]) def 
test_multiple_fits_regression(column_info, nrows, n_estimators, n_bins): datatype = np.float32 ncols, n_info = column_info - X, y = make_regression(n_samples=nrows, n_features=ncols, - n_informative=n_info, - random_state=123) + X, y = make_regression( + n_samples=nrows, + n_features=ncols, + n_informative=n_info, + random_state=123, + ) X = X.astype(datatype) y = y.astype(np.int32) - cuml_model = curfr(n_bins=n_bins, - n_estimators=n_estimators, - max_depth=10) + cuml_model = curfr(n_bins=n_bins, n_estimators=n_estimators, max_depth=10) # Calling multiple fits cuml_model.fit(X, y) @@ -752,17 +860,22 @@ def test_multiple_fits_regression(column_info, nrows, n_estimators, n_bins): # Check if params are still intact params = cuml_model.get_params() - assert params['n_estimators'] == n_estimators - assert params['n_bins'] == n_bins + assert params["n_estimators"] == n_estimators + assert params["n_bins"] == n_bins -@pytest.mark.parametrize('n_estimators', [5, 10, 20]) -@pytest.mark.parametrize('detailed_text', [True, False]) +@pytest.mark.parametrize("n_estimators", [5, 10, 20]) +@pytest.mark.parametrize("detailed_text", [True, False]) def test_rf_get_text(n_estimators, detailed_text): - X, y = make_classification(n_samples=500, n_features=10, - n_clusters_per_class=1, n_informative=5, - random_state=94929, n_classes=2) + X, y = make_classification( + n_samples=500, + n_features=10, + n_clusters_per_class=1, + n_informative=5, + random_state=94929, + n_classes=2, + ) X = X.astype(np.float32) y = y.astype(np.int32) @@ -771,11 +884,19 @@ def test_rf_get_text(n_estimators, detailed_text): handle, stream = get_handle(True, n_streams=1) # Initialize cuML Random Forest classification model - cuml_model = curfc(handle=handle, max_features=1.0, max_samples=1.0, - n_bins=16, split_algo=0, split_criterion=0, - min_samples_leaf=2, random_state=23707, n_streams=1, - n_estimators=n_estimators, max_leaves=-1, - max_depth=16) + cuml_model = curfc( + handle=handle, + max_features=1.0, + max_samples=1.0, + n_bins=16, + split_criterion=0, + min_samples_leaf=2, + random_state=23707, + n_streams=1, + n_estimators=n_estimators, + max_leaves=-1, + max_depth=16, + ) # Train model on the data cuml_model.fit(X, y) @@ -786,39 +907,57 @@ def test_rf_get_text(n_estimators, detailed_text): text_output = cuml_model.get_summary_text() # Test 1: Output is non-zero - assert '' != text_output + assert "" != text_output # Count the number of trees printed tree_count = 0 - for line in text_output.split('\n'): - if line.strip().startswith('Tree #'): + for line in text_output.split("\n"): + if line.strip().startswith("Tree #"): tree_count += 1 # Test 2: Correct number of trees are printed assert n_estimators == tree_count -@pytest.mark.parametrize('max_depth', [1, 2, 3, 5, 10, 15, 20]) -@pytest.mark.parametrize('n_estimators', [5, 10, 20]) -@pytest.mark.parametrize('estimator_type', ['regression', 'classification']) +@pytest.mark.parametrize("max_depth", [1, 2, 3, 5, 10, 15, 20]) +@pytest.mark.parametrize("n_estimators", [5, 10, 20]) +@pytest.mark.parametrize("estimator_type", ["regression", "classification"]) def test_rf_get_json(estimator_type, max_depth, n_estimators): - X, y = make_classification(n_samples=350, n_features=20, - n_clusters_per_class=1, n_informative=10, - random_state=123, n_classes=2) + X, y = make_classification( + n_samples=350, + n_features=20, + n_clusters_per_class=1, + n_informative=10, + random_state=123, + n_classes=2, + ) X = X.astype(np.float32) - if estimator_type == 'classification': - cuml_model = 
curfc(max_features=1.0, max_samples=1.0, - n_bins=16, split_algo=0, split_criterion=0, - min_samples_leaf=2, random_state=23707, n_streams=1, - n_estimators=n_estimators, max_leaves=-1, - max_depth=max_depth) + if estimator_type == "classification": + cuml_model = curfc( + max_features=1.0, + max_samples=1.0, + n_bins=16, + split_criterion=0, + min_samples_leaf=2, + random_state=23707, + n_streams=1, + n_estimators=n_estimators, + max_leaves=-1, + max_depth=max_depth, + ) y = y.astype(np.int32) - elif estimator_type == 'regression': - cuml_model = curfr(max_features=1.0, max_samples=1.0, - n_bins=16, split_algo=0, - min_samples_leaf=2, random_state=23707, n_streams=1, - n_estimators=n_estimators, max_leaves=-1, - max_depth=max_depth) + elif estimator_type == "regression": + cuml_model = curfr( + max_features=1.0, + max_samples=1.0, + n_bins=16, + min_samples_leaf=2, + random_state=23707, + n_streams=1, + n_estimators=n_estimators, + max_leaves=-1, + max_depth=max_depth, + ) y = y.astype(np.float32) else: assert False @@ -830,7 +969,7 @@ def test_rf_get_json(estimator_type, max_depth, n_estimators): json_obj = json.loads(json_out) # Test 1: Output is non-zero - assert '' != json_out + assert "" != json_out # Test 2: JSON object contains correct number of trees assert isinstance(json_obj, list) @@ -838,16 +977,16 @@ def test_rf_get_json(estimator_type, max_depth, n_estimators): # Test 3: Traverse JSON trees and get the same predictions as cuML RF def predict_with_json_tree(tree, x): - if 'children' not in tree: - assert 'leaf_value' in tree - return tree['leaf_value'] - assert 'split_feature' in tree - assert 'split_threshold' in tree - assert 'yes' in tree - assert 'no' in tree - if x[tree['split_feature']] <= tree['split_threshold']: - return predict_with_json_tree(tree['children'][0], x) - return predict_with_json_tree(tree['children'][1], x) + if "children" not in tree: + assert "leaf_value" in tree + return tree["leaf_value"] + assert "split_feature" in tree + assert "split_threshold" in tree + assert "yes" in tree + assert "no" in tree + if x[tree["split_feature"]] <= tree["split_threshold"] + 1e-5: + return predict_with_json_tree(tree["children"][0], x) + return predict_with_json_tree(tree["children"][1], x) def predict_with_json_rf_classifier(rf, x): # Returns the class with the highest vote. If there is a tie, return @@ -861,37 +1000,52 @@ def predict_with_json_rf_classifier(rf, x): return majority_vote def predict_with_json_rf_regressor(rf, x): - pred = 0. 
+ pred = 0.0 for tree in rf: pred += predict_with_json_tree(tree, x) return pred / len(rf) - if estimator_type == 'classification': + if estimator_type == "classification": expected_pred = cuml_model.predict(X).astype(np.int32) for idx, row in enumerate(X): majority_vote = predict_with_json_rf_classifier(json_obj, row) assert expected_pred[idx] in majority_vote - elif estimator_type == 'regression': + elif estimator_type == "regression": expected_pred = cuml_model.predict(X).astype(np.float32) pred = [] for idx, row in enumerate(X): pred.append(predict_with_json_rf_regressor(json_obj, row)) pred = np.array(pred, dtype=np.float32) np.testing.assert_almost_equal(pred, expected_pred, decimal=6) -@pytest.mark.parametrize('max_depth', [1, 2, 3, 5, 10, 15, 20]) -@pytest.mark.parametrize('n_estimators', [5, 10, 20]) +@pytest.mark.parametrize("max_depth", [1, 2, 3, 5, 10, 15, 20]) +@pytest.mark.parametrize("n_estimators", [5, 10, 20]) def test_rf_instance_count(max_depth, n_estimators): - X, y = make_classification(n_samples=350, n_features=20, - n_clusters_per_class=1, n_informative=10, - random_state=123, n_classes=2) + X, y = make_classification( + n_samples=350, + n_features=20, + n_clusters_per_class=1, + n_informative=10, + random_state=123, + n_classes=2, + ) X = X.astype(np.float32) - cuml_model = curfc(max_features=1.0, max_samples=1.0, - n_bins=16, split_algo=1, split_criterion=0, - min_samples_leaf=2, random_state=23707, n_streams=1, - n_estimators=n_estimators, max_leaves=-1, - max_depth=max_depth) + cuml_model = curfc( + max_features=1.0, + max_samples=1.0, + n_bins=16, + split_criterion=0, + min_samples_leaf=2, + random_state=23707, + n_streams=1, + n_estimators=n_estimators, + max_leaves=-1, + max_depth=max_depth, + ) y = y.astype(np.int32) # Train model on the data @@ -904,24 +1058,27 @@ def test_rf_instance_count(max_depth, n_estimators): # the instance counts of its children. Note that the instance count # is only available with the new backend. 
def check_instance_count_for_non_leaf(tree): - assert 'instance_count' in tree - if 'children' not in tree: + assert "instance_count" in tree + if "children" not in tree: return - assert 'instance_count' in tree['children'][0] - assert 'instance_count' in tree['children'][1] - assert (tree['instance_count'] - == tree['children'][0]['instance_count'] - + tree['children'][1]['instance_count']) - check_instance_count_for_non_leaf(tree['children'][0]) - check_instance_count_for_non_leaf(tree['children'][1]) + assert "instance_count" in tree["children"][0] + assert "instance_count" in tree["children"][1] + assert ( + tree["instance_count"] + == tree["children"][0]["instance_count"] + + tree["children"][1]["instance_count"] + ) + check_instance_count_for_non_leaf(tree["children"][0]) + check_instance_count_for_non_leaf(tree["children"][1]) + for tree in json_obj: check_instance_count_for_non_leaf(tree) # The root's count must be equal to the number of rows in the data - assert tree['instance_count'] == X.shape[0] + assert tree["instance_count"] == X.shape[0] @pytest.mark.memleak -@pytest.mark.parametrize('estimator_type', ['classification']) +@pytest.mark.parametrize("estimator_type", ["classification"]) def test_rf_host_memory_leak(large_clf, estimator_type): import gc import os @@ -935,16 +1092,12 @@ def test_rf_host_memory_leak(large_clf, estimator_type): X, y = large_clf X = X.astype(np.float32) - params = {'max_depth': 50} - if estimator_type == 'classification': - base_model = curfc(max_depth=10, - n_estimators=100, - random_state=123) + params = {"max_depth": 50} + if estimator_type == "classification": + base_model = curfc(max_depth=10, n_estimators=100, random_state=123) y = y.astype(np.int32) else: - base_model = curfr(max_depth=10, - n_estimators=100, - random_state=123) + base_model = curfr(max_depth=10, n_estimators=100, random_state=123) y = y.astype(np.float32) # Pre-fit once - this is our baseline and memory usage @@ -965,7 +1118,7 @@ def test_rf_host_memory_leak(large_clf, estimator_type): @pytest.mark.memleak -@pytest.mark.parametrize('estimator_type', ['regression', 'classification']) +@pytest.mark.parametrize("estimator_type", ["regression", "classification"]) def test_concat_memory_leak(large_clf, estimator_type): import gc import os @@ -982,15 +1135,17 @@ def test_concat_memory_leak(large_clf, estimator_type): # Build a series of RF models n_models = 10 - if estimator_type == 'classification': - base_models = [curfc(max_depth=10, - n_estimators=100, - random_state=123) for i in range(n_models)] + if estimator_type == "classification": + base_models = [ + curfc(max_depth=10, n_estimators=100, random_state=123) + for i in range(n_models) + ] y = y.astype(np.int32) - elif estimator_type == 'regression': - base_models = [curfr(max_depth=10, - n_estimators=100, - random_state=123) for i in range(n_models)] + elif estimator_type == "regression": + base_models = [ + curfr(max_depth=10, n_estimators=100, random_state=123) + for i in range(n_models) + ] y = y.astype(np.float32) else: assert False @@ -1014,13 +1169,16 @@ def test_concat_memory_leak(large_clf, estimator_type): init_model._concatenate_treelite_handle(other_handles) gc.collect() used_mem = process.memory_info().rss - logger.debug("memory at rep %2d: %d m" % ( - i, (used_mem - initial_baseline_mem)/1e6)) + logger.debug( + "memory at rep %2d: %d m" + % (i, (used_mem - initial_baseline_mem) / 1e6) + ) gc.collect() used_mem = process.memory_info().rss - logger.info("Final memory delta: %d" % ( - (used_mem - 
initial_baseline_mem)/1e6)) + logger.info( + "Final memory delta: %d" % ((used_mem - initial_baseline_mem) / 1e6) + ) assert (used_mem - initial_baseline_mem) < 1e6 @@ -1030,34 +1188,39 @@ def test_rf_nbins_small(small_clf): X, y = small_clf X = X.astype(np.float32) y = y.astype(np.int32) - X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, - random_state=0) + X_train, X_test, y_train, y_test = train_test_split( + X, y, train_size=0.8, random_state=0 + ) # Initialize, fit and predict using cuML's # random forest classification model cuml_model = curfc() cuml_model.fit(X_train[0:3, :], y_train[0:3]) -@pytest.mark.parametrize('split_criterion', [2], ids=['mse']) -@pytest.mark.parametrize('use_experimental_backend', [True, False]) -def test_rf_regression_with_identical_labels(split_criterion, - use_experimental_backend): +@pytest.mark.parametrize("split_criterion", [2], ids=["mse"]) +def test_rf_regression_with_identical_labels(split_criterion): X = np.array([[-1, 0], [0, 1], [2, 0], [0, 3], [-2, 0]], dtype=np.float32) y = np.array([1, 1, 1, 1, 1], dtype=np.float32) # Degenerate case: all labels are identical. # RF Regressor must not create any split. It must yield an empty tree # with only the root node. - clf = curfr(max_features=1.0, max_samples=1.0, n_bins=5, split_algo=1, - bootstrap=False, split_criterion=split_criterion, - min_samples_leaf=1, min_samples_split=2, random_state=0, - n_streams=1, n_estimators=1, max_depth=1, - use_experimental_backend=use_experimental_backend) + clf = curfr( + max_features=1.0, + max_samples=1.0, + n_bins=5, + bootstrap=False, + split_criterion=split_criterion, + min_samples_leaf=1, + min_samples_split=2, + random_state=0, + n_streams=1, + n_estimators=1, + max_depth=1, + ) clf.fit(X, y) model_dump = json.loads(clf.get_json()) assert len(model_dump) == 1 - expected_dump = {'nodeid': 0, 'leaf_value': 1.0} - if use_experimental_backend: - expected_dump['instance_count'] = 5 + expected_dump = {"nodeid": 0, "leaf_value": 1.0, "instance_count": 5} assert model_dump[0] == expected_dump
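
Reviewer note: the following is a minimal, self-contained sketch (not part of the patch) of the public parameter surface after this change. None of the removed arguments (split_algo, bootstrap_features, quantile_per_tree, use_experimental_backend) are passed, and the per-node instance_count field that test_rf_instance_count relies on is read back from the JSON dump. Hyperparameter values are copied from the updated tests; it assumes a cuML build with this patch applied.

import json
import numpy as np
from cuml.ensemble import RandomForestClassifier as curfc
from sklearn.datasets import make_classification

# Same synthetic data as the updated tests.
X, y = make_classification(n_samples=350, n_features=20,
                           n_clusters_per_class=1, n_informative=10,
                           random_state=123, n_classes=2)
X = X.astype(np.float32)
y = y.astype(np.int32)

# Only surviving constructor arguments are used; per this patch,
# passing split_algo or use_experimental_backend now only triggers a
# deprecation warning (see randomforest_common.pyx above).
cuml_model = curfc(max_features=1.0, max_samples=1.0, n_bins=16,
                   split_criterion=0, min_samples_leaf=2,
                   random_state=23707, n_streams=1, n_estimators=10,
                   max_leaves=-1, max_depth=8)
cuml_model.fit(X, y)

# With max_samples=1.0, each tree root's instance_count equals the
# number of training rows, as asserted in test_rf_instance_count.
for tree in json.loads(cuml_model.get_json()):
    assert tree["instance_count"] == X.shape[0]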