diff --git a/cpp/src/decisiontree/decisiontree.cu b/cpp/src/decisiontree/decisiontree.cu index 949da895c7..233148d73f 100644 --- a/cpp/src/decisiontree/decisiontree.cu +++ b/cpp/src/decisiontree/decisiontree.cu @@ -159,8 +159,21 @@ void decisionTreeClassifierFit(const raft::handle_t &handle, uint64_t seed) { std::shared_ptr> dt_classifier = std::make_shared>(); + std::unique_ptr> global_quantiles_buffer = + nullptr; + float *global_quantiles = nullptr; + + if (tree_params.use_experimental_backend) { + auto quantile_size = tree_params.n_bins * ncols; + global_quantiles_buffer = std::make_unique>( + handle.get_device_allocator(), handle.get_stream(), quantile_size); + global_quantiles = global_quantiles_buffer->data(); + DecisionTree::computeQuantiles(global_quantiles, tree_params.n_bins, data, + nrows, ncols, handle.get_device_allocator(), + handle.get_stream()); + } dt_classifier->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows, - unique_labels, tree, tree_params, seed); + unique_labels, tree, tree_params, seed, global_quantiles); } void decisionTreeClassifierFit(const raft::handle_t &handle, @@ -172,8 +185,21 @@ void decisionTreeClassifierFit(const raft::handle_t &handle, uint64_t seed) { std::shared_ptr> dt_classifier = std::make_shared>(); + std::unique_ptr> global_quantiles_buffer = + nullptr; + double *global_quantiles = nullptr; + + if (tree_params.use_experimental_backend) { + auto quantile_size = tree_params.n_bins * ncols; + global_quantiles_buffer = std::make_unique>( + handle.get_device_allocator(), handle.get_stream(), quantile_size); + global_quantiles = global_quantiles_buffer->data(); + DecisionTree::computeQuantiles(global_quantiles, tree_params.n_bins, data, + nrows, ncols, handle.get_device_allocator(), + handle.get_stream()); + } dt_classifier->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows, - unique_labels, tree, tree_params, seed); + unique_labels, tree, tree_params, seed, global_quantiles); } void decisionTreeClassifierPredict(const raft::handle_t &handle, @@ -208,8 +234,21 @@ void decisionTreeRegressorFit(const raft::handle_t &handle, uint64_t seed) { std::shared_ptr> dt_regressor = std::make_shared>(); + std::unique_ptr> global_quantiles_buffer = + nullptr; + float *global_quantiles = nullptr; + + if (tree_params.use_experimental_backend) { + auto quantile_size = tree_params.n_bins * ncols; + global_quantiles_buffer = std::make_unique>( + handle.get_device_allocator(), handle.get_stream(), quantile_size); + global_quantiles = global_quantiles_buffer->data(); + DecisionTree::computeQuantiles(global_quantiles, tree_params.n_bins, data, + nrows, ncols, handle.get_device_allocator(), + handle.get_stream()); + } dt_regressor->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows, - tree, tree_params, seed); + tree, tree_params, seed, global_quantiles); } void decisionTreeRegressorFit(const raft::handle_t &handle, @@ -220,8 +259,21 @@ void decisionTreeRegressorFit(const raft::handle_t &handle, uint64_t seed) { std::shared_ptr> dt_regressor = std::make_shared>(); + std::unique_ptr> global_quantiles_buffer = + nullptr; + double *global_quantiles = nullptr; + + if (tree_params.use_experimental_backend) { + auto quantile_size = tree_params.n_bins * ncols; + global_quantiles_buffer = std::make_unique>( + handle.get_device_allocator(), handle.get_stream(), quantile_size); + global_quantiles = global_quantiles_buffer->data(); + DecisionTree::computeQuantiles(global_quantiles, tree_params.n_bins, data, + nrows, ncols, handle.get_device_allocator(), + handle.get_stream()); + } dt_regressor->fit(handle, data, ncols, nrows, labels, rowids, n_sampled_rows, - tree, tree_params, seed); + tree, tree_params, seed, global_quantiles); } void decisionTreeRegressorPredict(const raft::handle_t &handle, diff --git a/cpp/src/decisiontree/decisiontree_impl.cuh b/cpp/src/decisiontree/decisiontree_impl.cuh index ce0b23fab4..98d6e5300c 100644 --- a/cpp/src/decisiontree/decisiontree_impl.cuh +++ b/cpp/src/decisiontree/decisiontree_impl.cuh @@ -29,7 +29,6 @@ #include "levelalgo/levelfunc_regressor.cuh" #include "levelalgo/metric.cuh" #include "memory.cuh" -#include "quantile/quantile.cuh" #include "quantile/quantile.h" #include "treelite_util.h" @@ -293,20 +292,9 @@ void DecisionTreeBase::plant( total_temp_mem = tempmem->totalmem; MLCommon::TimerCPU timer; - if (tree_params.use_experimental_backend) { - if (treeid == 0) { - CUML_LOG_WARN("Using experimental backend for growing trees\n"); - } - T *quantiles = tempmem->d_quantile->data(); - grow_tree(tempmem->device_allocator, tempmem->host_allocator, data, treeid, - seed, ncols, nrows, labels, quantiles, (int *)rowids, - n_sampled_rows, unique_labels, tree_params, tempmem->stream, - sparsetree, this->leaf_counter, this->depth_counter); - } else { - grow_deep_tree(data, labels, rowids, n_sampled_rows, ncols, - tree_params.max_features, dinfo.NLocalrows, sparsetree, - treeid, tempmem); - } + grow_deep_tree(data, labels, rowids, n_sampled_rows, ncols, + tree_params.max_features, dinfo.NLocalrows, sparsetree, treeid, + tempmem); train_time = timer.getElapsedSeconds(); ML::POP_RANGE(); } @@ -379,7 +367,7 @@ void DecisionTreeBase::base_fit( const cudaStream_t stream_in, const T *data, const int ncols, const int nrows, const L *labels, unsigned int *rowids, const int n_sampled_rows, int unique_labels, std::vector> &sparsetree, - const int treeid, uint64_t seed, bool is_classifier, + const int treeid, uint64_t seed, bool is_classifier, T *d_global_quantiles, std::shared_ptr> in_tempmem) { prepare_fit_timer.reset(); const char *CRITERION_NAME[] = {"GINI", "ENTROPY", "MSE", "MAE", "END"}; @@ -406,19 +394,37 @@ void DecisionTreeBase::base_fit( "Unsupported criterion %s\n", CRITERION_NAME[tree_params.split_criterion]); - if (in_tempmem != nullptr) { - tempmem = in_tempmem; - } else { - tempmem = std::make_shared>( - device_allocator_in, host_allocator_in, stream_in, nrows, ncols, - unique_labels, tree_params); - tree_params.quantile_per_tree = true; + if (!tree_params.use_experimental_backend) { + // Only execute for level backend as temporary memory is unused in batched + // backend. + if (in_tempmem != nullptr) { + tempmem = in_tempmem; + } else { + tempmem = std::make_shared>( + device_allocator_in, host_allocator_in, stream_in, nrows, ncols, + unique_labels, tree_params); + tree_params.quantile_per_tree = true; + } } - plant(sparsetree, data, ncols, nrows, labels, rowids, n_sampled_rows, - unique_labels, treeid, seed); - if (in_tempmem == nullptr) { - tempmem.reset(); + if (tree_params.use_experimental_backend) { + dinfo.NLocalrows = nrows; + dinfo.NGlobalrows = nrows; + dinfo.Ncols = ncols; + n_unique_labels = unique_labels; + if (treeid == 0) { + CUML_LOG_WARN("Using experimental backend for growing trees\n"); + } + grow_tree(device_allocator_in, host_allocator_in, data, treeid, seed, ncols, + nrows, labels, d_global_quantiles, (int *)rowids, n_sampled_rows, + unique_labels, tree_params, stream_in, sparsetree, + this->leaf_counter, this->depth_counter); + } else { + plant(sparsetree, data, ncols, nrows, labels, rowids, n_sampled_rows, + unique_labels, treeid, seed); + if (in_tempmem == nullptr) { + tempmem.reset(); + } } } @@ -427,13 +433,13 @@ void DecisionTreeClassifier::fit( const raft::handle_t &handle, const T *data, const int ncols, const int nrows, const int *labels, unsigned int *rowids, const int n_sampled_rows, const int unique_labels, TreeMetaDataNode *&tree, - DecisionTreeParams tree_parameters, uint64_t seed, + DecisionTreeParams tree_parameters, uint64_t seed, T *d_global_quantiles, std::shared_ptr> in_tempmem) { this->tree_params = tree_parameters; this->base_fit(handle.get_device_allocator(), handle.get_host_allocator(), handle.get_stream(), data, ncols, nrows, labels, rowids, n_sampled_rows, unique_labels, tree->sparsetree, tree->treeid, - seed, true, in_tempmem); + seed, true, d_global_quantiles, in_tempmem); this->set_metadata(tree); } @@ -445,12 +451,13 @@ void DecisionTreeClassifier::fit( const cudaStream_t stream_in, const T *data, const int ncols, const int nrows, const int *labels, unsigned int *rowids, const int n_sampled_rows, const int unique_labels, TreeMetaDataNode *&tree, - DecisionTreeParams tree_parameters, uint64_t seed, + DecisionTreeParams tree_parameters, uint64_t seed, T *d_global_quantiles, std::shared_ptr> in_tempmem) { this->tree_params = tree_parameters; this->base_fit(device_allocator_in, host_allocator_in, stream_in, data, ncols, nrows, labels, rowids, n_sampled_rows, unique_labels, - tree->sparsetree, tree->treeid, seed, true, in_tempmem); + tree->sparsetree, tree->treeid, seed, true, d_global_quantiles, + in_tempmem); this->set_metadata(tree); } @@ -459,12 +466,13 @@ void DecisionTreeRegressor::fit( const raft::handle_t &handle, const T *data, const int ncols, const int nrows, const T *labels, unsigned int *rowids, const int n_sampled_rows, TreeMetaDataNode *&tree, DecisionTreeParams tree_parameters, - uint64_t seed, std::shared_ptr> in_tempmem) { + uint64_t seed, T *d_global_quantiles, + std::shared_ptr> in_tempmem) { this->tree_params = tree_parameters; this->base_fit(handle.get_device_allocator(), handle.get_host_allocator(), handle.get_stream(), data, ncols, nrows, labels, rowids, n_sampled_rows, 1, tree->sparsetree, tree->treeid, seed, false, - in_tempmem); + d_global_quantiles, in_tempmem); this->set_metadata(tree); } @@ -475,11 +483,12 @@ void DecisionTreeRegressor::fit( const cudaStream_t stream_in, const T *data, const int ncols, const int nrows, const T *labels, unsigned int *rowids, const int n_sampled_rows, TreeMetaDataNode *&tree, DecisionTreeParams tree_parameters, - uint64_t seed, std::shared_ptr> in_tempmem) { + uint64_t seed, T *d_global_quantiles, + std::shared_ptr> in_tempmem) { this->tree_params = tree_parameters; this->base_fit(device_allocator_in, host_allocator_in, stream_in, data, ncols, nrows, labels, rowids, n_sampled_rows, 1, tree->sparsetree, - tree->treeid, seed, false, in_tempmem); + tree->treeid, seed, false, d_global_quantiles, in_tempmem); this->set_metadata(tree); } diff --git a/cpp/src/decisiontree/decisiontree_impl.h b/cpp/src/decisiontree/decisiontree_impl.h index b39acc5088..b79463fede 100644 --- a/cpp/src/decisiontree/decisiontree_impl.h +++ b/cpp/src/decisiontree/decisiontree_impl.h @@ -106,7 +106,7 @@ class DecisionTreeBase { const int nrows, const L *labels, unsigned int *rowids, const int n_sampled_rows, int unique_labels, std::vector> &sparsetree, const int treeid, - uint64_t seed, bool is_classifier, + uint64_t seed, bool is_classifier, T *d_global_quantiles, std::shared_ptr> in_tempmem); public: @@ -140,7 +140,7 @@ class DecisionTreeClassifier : public DecisionTreeBase { const int nrows, const int *labels, unsigned int *rowids, const int n_sampled_rows, const int unique_labels, TreeMetaDataNode *&tree, DecisionTreeParams tree_parameters, - uint64_t seed, + uint64_t seed, T *d_quantiles, std::shared_ptr> in_tempmem = nullptr); //This fit fucntion does not take handle , used by RF @@ -150,7 +150,8 @@ class DecisionTreeClassifier : public DecisionTreeBase { const int nrows, const int *labels, unsigned int *rowids, const int n_sampled_rows, const int unique_labels, TreeMetaDataNode *&tree, DecisionTreeParams tree_parameters, - uint64_t seed, std::shared_ptr> in_tempmem); + uint64_t seed, T *d_quantiles, + std::shared_ptr> in_tempmem); private: void grow_deep_tree(const T *data, const int *labels, unsigned int *rowids, @@ -168,7 +169,7 @@ class DecisionTreeRegressor : public DecisionTreeBase { void fit(const raft::handle_t &handle, const T *data, const int ncols, const int nrows, const T *labels, unsigned int *rowids, const int n_sampled_rows, TreeMetaDataNode *&tree, - DecisionTreeParams tree_parameters, uint64_t seed, + DecisionTreeParams tree_parameters, uint64_t seed, T *d_quantiles, std::shared_ptr> in_tempmem = nullptr); //This fit function does not take handle. Used by RF @@ -177,7 +178,7 @@ class DecisionTreeRegressor : public DecisionTreeBase { const cudaStream_t stream_in, const T *data, const int ncols, const int nrows, const T *labels, unsigned int *rowids, const int n_sampled_rows, TreeMetaDataNode *&tree, - DecisionTreeParams tree_parameters, uint64_t seed, + DecisionTreeParams tree_parameters, uint64_t seed, T *d_quantiles, std::shared_ptr> in_tempmem); private: diff --git a/cpp/src/decisiontree/quantile/quantile.cuh b/cpp/src/decisiontree/quantile/quantile.cuh index ae651074b6..75458dceac 100644 --- a/cpp/src/decisiontree/quantile/quantile.cuh +++ b/cpp/src/decisiontree/quantile/quantile.cuh @@ -15,8 +15,10 @@ */ #pragma once -#include #include +#include +#include +#include #include "quantile.h" #include @@ -24,6 +26,10 @@ namespace ML { namespace DecisionTree { +using device_allocator = raft::mr::device::allocator; +template +using device_buffer = raft::mr::device::buffer; + template __global__ void allcolsampler_kernel(const T *__restrict__ data, const unsigned int *__restrict__ rowids, @@ -183,5 +189,63 @@ void preprocess_quantile(const T *data, const unsigned int *rowids, return; } +template +__global__ void computeQuantilesSorted(T *quantiles, const int n_bins, + const T *sorted_data, const int length) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + double bin_width = static_cast(length) / n_bins; + int index = int(round((tid + 1) * bin_width)) - 1; + // Old way of computing quantiles. Kept here for comparison. + // To be deleted eventually + // int index = (tid + 1) * floor(bin_width) - 1; + if (tid < n_bins) { + quantiles[tid] = sorted_data[index]; + } + + return; +} + +template +void computeQuantiles(T *quantiles, int n_bins, const T *data, int n_rows, + int n_cols, + const std::shared_ptr device_allocator, + cudaStream_t stream) { + // Determine temporary device storage requirements + std::unique_ptr> d_temp_storage = nullptr; + size_t temp_storage_bytes = 0; + + std::unique_ptr> single_column_sorted = nullptr; + single_column_sorted = + std::make_unique>(device_allocator, stream, n_rows); + + CUDA_CHECK(cub::DeviceRadixSort::SortKeys(nullptr, temp_storage_bytes, data, + single_column_sorted->data(), + n_rows, 0, 8 * sizeof(T), stream)); + + // Allocate temporary storage for sorting + d_temp_storage = std::make_unique>( + device_allocator, stream, temp_storage_bytes); + + // Compute quantiles column by column + for (int col = 0; col < n_cols; col++) { + int col_offset = col * n_rows; + int quantile_offset = col * n_bins; + + CUDA_CHECK(cub::DeviceRadixSort::SortKeys( + (void *)d_temp_storage->data(), temp_storage_bytes, &data[col_offset], + single_column_sorted->data(), n_rows, 0, 8 * sizeof(T), stream)); + + int blocks = raft::ceildiv(n_bins, 128); + + computeQuantilesSorted<<>>( + &quantiles[quantile_offset], n_bins, single_column_sorted->data(), + n_rows); + + CUDA_CHECK(cudaGetLastError()); + } + + return; +} + } // namespace DecisionTree } // namespace ML diff --git a/cpp/src/decisiontree/quantile/quantile.h b/cpp/src/decisiontree/quantile/quantile.h index 91d953288f..a7b5933969 100644 --- a/cpp/src/decisiontree/quantile/quantile.h +++ b/cpp/src/decisiontree/quantile/quantile.h @@ -17,9 +17,11 @@ #pragma once #include +#include template struct TemporaryMemory; +using deviceAllocator = raft::mr::device::allocator; namespace ML { namespace DecisionTree { @@ -30,5 +32,11 @@ void preprocess_quantile(const T *data, const unsigned int *rowids, const int rowoffset, const int nbins, std::shared_ptr> tempmem); +template +void computeQuantiles(T *quantiles, int n_bins, const T *data, int n_rows, + int n_cols, + const std::shared_ptr device_allocator, + cudaStream_t stream); + } // namespace DecisionTree } // namespace ML diff --git a/cpp/src/randomforest/randomforest_impl.cuh b/cpp/src/randomforest/randomforest_impl.cuh index 091996f2cb..6c2bfb5f5a 100644 --- a/cpp/src/randomforest/randomforest_impl.cuh +++ b/cpp/src/randomforest/randomforest_impl.cuh @@ -18,10 +18,10 @@ #define omp_get_thread_num() 0 #endif #include -#include #include #include #include +#include #include #include #include @@ -179,7 +179,6 @@ void rfClassifier::fit(const raft::handle_t& user_handle, const T* input, "rf_params.n_streams (=%d) should be <= raft::handle_t.n_streams (=%d)", n_streams, handle.get_num_internal_streams()); - cudaStream_t stream = handle.get_stream(); // Select n_sampled_rows (with replacement) numbers from [0, n_rows) per tree. // selected_rows: randomly generated IDs for bootstrapped samples (w/ replacement); a device ptr. MLCommon::device_buffer* selected_rows[n_streams]; @@ -190,37 +189,63 @@ void rfClassifier::fit(const raft::handle_t& user_handle, const T* input, } std::shared_ptr> tempmem[n_streams]; - for (int i = 0; i < n_streams; i++) { - tempmem[i] = std::make_shared>( - handle, handle.get_internal_stream(i), n_rows, n_cols, n_unique_labels, - this->rf_params.tree_params); + if (this->rf_params.tree_params.use_experimental_backend) { + // TemporaryMemory is unused for batched (new) backend + for (int i = 0; i < n_streams; i++) { + tempmem[i] = nullptr; + } + } else { + // Allocate TemporaryMemory for each stream + for (int i = 0; i < n_streams; i++) { + tempmem[i] = std::make_shared>( + handle, handle.get_internal_stream(i), n_rows, n_cols, n_unique_labels, + this->rf_params.tree_params); + } } + + std::unique_ptr> global_quantiles_buffer = nullptr; + T* global_quantiles = nullptr; + auto quantile_size = this->rf_params.tree_params.n_bins * n_cols; + //Preprocess once only per forest - if ((this->rf_params.tree_params.split_algo == SPLIT_ALGO::GLOBAL_QUANTILE) && - !(this->rf_params.tree_params.quantile_per_tree)) { - DecisionTree::preprocess_quantile(input, nullptr, n_rows, n_cols, n_rows, - this->rf_params.tree_params.n_bins, - tempmem[0]); - for (int i = 1; i < n_streams; i++) { - CUDA_CHECK(cudaMemcpyAsync( - tempmem[i]->d_quantile->data(), tempmem[0]->d_quantile->data(), - this->rf_params.tree_params.n_bins * n_cols * sizeof(T), - cudaMemcpyDeviceToDevice, tempmem[i]->stream)); - memcpy((void*)(tempmem[i]->h_quantile->data()), - (void*)(tempmem[0]->h_quantile->data()), - this->rf_params.tree_params.n_bins * n_cols * sizeof(T)); + if (this->rf_params.tree_params.use_experimental_backend) { + // Using batched backend + // allocate space for d_global_quantiles + global_quantiles_buffer = std::make_unique>( + handle.get_device_allocator(), handle.get_stream(), quantile_size); + global_quantiles = global_quantiles_buffer->data(); + DecisionTree::computeQuantiles( + global_quantiles, this->rf_params.tree_params.n_bins, input, n_rows, + n_cols, handle.get_device_allocator(), handle.get_stream()); + CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); + } else { + if ((this->rf_params.tree_params.split_algo == + SPLIT_ALGO::GLOBAL_QUANTILE) && + !(this->rf_params.tree_params.quantile_per_tree)) { + // Using level (old) backend + DecisionTree::preprocess_quantile(input, nullptr, n_rows, n_cols, n_rows, + this->rf_params.tree_params.n_bins, + tempmem[0]); + for (int i = 1; i < n_streams; i++) { + CUDA_CHECK(cudaMemcpyAsync( + tempmem[i]->d_quantile->data(), tempmem[0]->d_quantile->data(), + this->rf_params.tree_params.n_bins * n_cols * sizeof(T), + cudaMemcpyDeviceToDevice, tempmem[i]->stream)); + memcpy((void*)(tempmem[i]->h_quantile->data()), + (void*)(tempmem[0]->h_quantile->data()), + this->rf_params.tree_params.n_bins * n_cols * sizeof(T)); + } } } #pragma omp parallel for num_threads(n_streams) for (int i = 0; i < this->rf_params.n_trees; i++) { int stream_id = omp_get_thread_num(); - unsigned int* rowids; - rowids = selected_rows[stream_id]->data(); + unsigned int* rowids = selected_rows[stream_id]->data(); this->prepare_fit_per_tree( - i, n_rows, n_sampled_rows, rowids, tempmem[stream_id]->num_sms, - tempmem[stream_id]->stream, handle.get_device_allocator()); + i, n_rows, n_sampled_rows, rowids, raft::getMultiProcessorCount(), + handle.get_internal_stream(stream_id), handle.get_device_allocator()); /* Build individual tree in the forest. - input is a pointer to orig data that have n_cols features and n_rows rows. @@ -233,22 +258,22 @@ void rfClassifier::fit(const raft::handle_t& user_handle, const T* input, DecisionTree::TreeMetaDataNode* tree_ptr = &(forest->trees[i]); tree_ptr->treeid = i; trees[i].fit(handle.get_device_allocator(), handle.get_host_allocator(), - tempmem[stream_id]->stream, input, n_cols, n_rows, labels, - rowids, n_sampled_rows, n_unique_labels, tree_ptr, + handle.get_internal_stream(stream_id), input, n_cols, n_rows, + labels, rowids, n_sampled_rows, n_unique_labels, tree_ptr, this->rf_params.tree_params, this->rf_params.seed, - tempmem[stream_id]); + global_quantiles, tempmem[stream_id]); } //Cleanup for (int i = 0; i < n_streams; i++) { - auto s = tempmem[i]->stream; + auto s = handle.get_internal_stream(i); CUDA_CHECK(cudaStreamSynchronize(s)); selected_rows[i]->release(s); - tempmem[i].reset(); delete selected_rows[i]; + if (!this->rf_params.tree_params.use_experimental_backend) { + tempmem[i].reset(); + } } - - CUDA_CHECK(cudaStreamSynchronize(user_handle.get_stream())); - + CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); ML::POP_RANGE(); } @@ -446,23 +471,22 @@ void rfRegressor::fit(const raft::handle_t& user_handle, const T* input, const raft::handle_t& handle = user_handle; int n_sampled_rows = 0; if (this->rf_params.bootstrap) { - n_sampled_rows = this->rf_params.max_samples * n_rows; + n_sampled_rows = std::round(this->rf_params.max_samples * n_rows); } else { if (this->rf_params.max_samples != 1.0) { CUML_LOG_WARN( "If bootstrap sampling is disabled, max_samples value is ignored and " "whole dataset is used for building each tree"); + this->rf_params.max_samples = 1.0; } n_sampled_rows = n_rows; } - int n_streams = this->rf_params.n_streams; ASSERT( n_streams <= handle.get_num_internal_streams(), "rf_params.n_streams (=%d) should be <= raft::handle_t.n_streams (=%d)", n_streams, handle.get_num_internal_streams()); - cudaStream_t stream = user_handle.get_stream(); // Select n_sampled_rows (with replacement) numbers from [0, n_rows) per tree. // selected_rows: randomly generated IDs for bootstrapped samples (w/ replacement); a device ptr. MLCommon::device_buffer* selected_rows[n_streams]; @@ -473,25 +497,52 @@ void rfRegressor::fit(const raft::handle_t& user_handle, const T* input, } std::shared_ptr> tempmem[n_streams]; - for (int i = 0; i < n_streams; i++) { - tempmem[i] = std::make_shared>( - handle, handle.get_internal_stream(i), n_rows, n_cols, 1, - this->rf_params.tree_params); + if (this->rf_params.tree_params.use_experimental_backend) { + // TemporaryMemory is unused for batched (new) backend + for (int i = 0; i < n_streams; i++) { + tempmem[i] = nullptr; + } + } else { + // Allocate TemporaryMemory for each stream + for (int i = 0; i < n_streams; i++) { + tempmem[i] = std::make_shared>( + handle, handle.get_internal_stream(i), n_rows, n_cols, 1, + this->rf_params.tree_params); + } } + + std::unique_ptr> global_quantiles_buffer = nullptr; + T* global_quantiles = nullptr; + auto quantile_size = this->rf_params.tree_params.n_bins * n_cols; + //Preprocess once only per forest - if ((this->rf_params.tree_params.split_algo == SPLIT_ALGO::GLOBAL_QUANTILE) && - !(this->rf_params.tree_params.quantile_per_tree)) { - DecisionTree::preprocess_quantile(input, nullptr, n_rows, n_cols, n_rows, - this->rf_params.tree_params.n_bins, - tempmem[0]); - for (int i = 1; i < n_streams; i++) { - CUDA_CHECK(cudaMemcpyAsync( - tempmem[i]->d_quantile->data(), tempmem[0]->d_quantile->data(), - this->rf_params.tree_params.n_bins * n_cols * sizeof(T), - cudaMemcpyDeviceToDevice, tempmem[i]->stream)); - memcpy((void*)(tempmem[i]->h_quantile->data()), - (void*)(tempmem[0]->h_quantile->data()), - this->rf_params.tree_params.n_bins * n_cols * sizeof(T)); + if (this->rf_params.tree_params.use_experimental_backend) { + // Using batched backend + // allocate space for d_global_quantiles + global_quantiles_buffer = std::make_unique>( + handle.get_device_allocator(), handle.get_stream(), quantile_size); + global_quantiles = global_quantiles_buffer->data(); + DecisionTree::computeQuantiles( + global_quantiles, this->rf_params.tree_params.n_bins, input, n_rows, + n_cols, handle.get_device_allocator(), handle.get_stream()); + CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); + } else { + if ((this->rf_params.tree_params.split_algo == + SPLIT_ALGO::GLOBAL_QUANTILE) && + !(this->rf_params.tree_params.quantile_per_tree)) { + // Using level (old) backend + DecisionTree::preprocess_quantile(input, nullptr, n_rows, n_cols, n_rows, + this->rf_params.tree_params.n_bins, + tempmem[0]); + for (int i = 1; i < n_streams; i++) { + CUDA_CHECK(cudaMemcpyAsync( + tempmem[i]->d_quantile->data(), tempmem[0]->d_quantile->data(), + this->rf_params.tree_params.n_bins * n_cols * sizeof(T), + cudaMemcpyDeviceToDevice, tempmem[i]->stream)); + memcpy((void*)(tempmem[i]->h_quantile->data()), + (void*)(tempmem[0]->h_quantile->data()), + this->rf_params.tree_params.n_bins * n_cols * sizeof(T)); + } } } @@ -499,9 +550,10 @@ void rfRegressor::fit(const raft::handle_t& user_handle, const T* input, for (int i = 0; i < this->rf_params.n_trees; i++) { int stream_id = omp_get_thread_num(); unsigned int* rowids = selected_rows[stream_id]->data(); + this->prepare_fit_per_tree( - i, n_rows, n_sampled_rows, rowids, tempmem[stream_id]->num_sms, - tempmem[stream_id]->stream, handle.get_device_allocator()); + i, n_rows, n_sampled_rows, rowids, raft::getMultiProcessorCount(), + handle.get_internal_stream(stream_id), handle.get_device_allocator()); /* Build individual tree in the forest. - input is a pointer to orig data that have n_cols features and n_rows rows. @@ -513,21 +565,22 @@ void rfRegressor::fit(const raft::handle_t& user_handle, const T* input, DecisionTree::TreeMetaDataNode* tree_ptr = &(forest->trees[i]); tree_ptr->treeid = i; trees[i].fit(handle.get_device_allocator(), handle.get_host_allocator(), - tempmem[stream_id]->stream, input, n_cols, n_rows, labels, - rowids, n_sampled_rows, tree_ptr, this->rf_params.tree_params, - this->rf_params.seed, tempmem[stream_id]); + handle.get_internal_stream(stream_id), input, n_cols, n_rows, + labels, rowids, n_sampled_rows, tree_ptr, + this->rf_params.tree_params, this->rf_params.seed, + global_quantiles, tempmem[stream_id]); } //Cleanup for (int i = 0; i < n_streams; i++) { - auto s = tempmem[i]->stream; + auto s = handle.get_internal_stream(i); CUDA_CHECK(cudaStreamSynchronize(s)); selected_rows[i]->release(s); - tempmem[i].reset(); delete selected_rows[i]; + if (!this->rf_params.tree_params.use_experimental_backend) { + tempmem[i].reset(); + } } - CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); - ML::POP_RANGE(); } diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 0376c28416..4a01c31d9c 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -67,6 +67,7 @@ if(BUILD_CUML_TESTS) sg/rf_batched_classification_test.cu sg/rf_batched_regression_test.cu sg/rf_depth_test.cu + sg/rf_quantiles_test.cu sg/rf_test.cu sg/rf_treelite_test.cu sg/ridge.cu diff --git a/cpp/test/sg/decisiontree_batchedlevel_unittest.cu b/cpp/test/sg/decisiontree_batchedlevel_unittest.cu index 3bd5e0a58a..e3ed1db140 100644 --- a/cpp/test/sg/decisiontree_batchedlevel_unittest.cu +++ b/cpp/test/sg/decisiontree_batchedlevel_unittest.cu @@ -15,12 +15,12 @@ */ #include +#include #include #include #include #include #include -#include #include namespace ML { diff --git a/cpp/test/sg/rf_quantiles_test.cu b/cpp/test/sg/rf_quantiles_test.cu new file mode 100644 index 0000000000..4e12702a9d --- /dev/null +++ b/cpp/test/sg/rf_quantiles_test.cu @@ -0,0 +1,185 @@ +/* + * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ML { + +using namespace MLCommon; + +// N O T E +// If the generated data has duplicate values at the quantile boundary, the +// test will fail. Probability of such a occurrence is low but should that +// happen, change the seed to workaround the issue. + +struct inputs { + int n_rows; + int n_bins; + uint64_t seed; +}; + +// Generate data with some outliers +template +__global__ void generateData(T* data, int length, uint64_t seed) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + raft::random::detail::Kiss99Generator gen(seed, 0, 0); + + T num = static_cast(0.0); + uint32_t toss, multiplier; + + for (int i = tid; i < length; i += blockDim.x * gridDim.x) { + // Generate a number + gen.next(multiplier); + multiplier &= 0xFF; + + gen.next(toss); + toss &= 0xFF; + + gen.next(num); + // Generate 5% outliers + // value of toss is in [0, 255], 5 % of that is 13 + if (toss < 13) { + // Number between [-multiplier, +multiplier] + data[i] = multiplier * (1 - 2 * num); + } else { + // Number between [-1, 1] + data[i] = (1 - 2 * num); + } + } + return; +} + +template +__global__ void computeHistogram(int* histogram, T* data, int length, + T* quantiles, int n_bins) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + for (int i = tid; i < length; i += blockDim.x * gridDim.x) { + T num = data[i]; + for (int j = 0; j < n_bins; j++) { + if (num <= quantiles[j]) { + atomicAdd(&histogram[j], 1); + break; + } + } + } + return; +} + +template +class RFQuantileTest : public ::testing::TestWithParam { + protected: + void SetUp() override { + params = ::testing::TestWithParam::GetParam(); + + CUDA_CHECK(cudaStreamCreate(&stream)); + handle.reset(new raft::handle_t()); + handle->set_stream(stream); + auto allocator = handle->get_device_allocator(); + auto h_allocator = handle->get_host_allocator(); + + data = (T*)allocator->allocate(params.n_rows * sizeof(T), stream); + quantiles = (T*)allocator->allocate(params.n_bins * sizeof(T), stream); + histogram = (int*)allocator->allocate(params.n_bins * sizeof(int), stream); + h_histogram = + (int*)h_allocator->allocate(params.n_bins * sizeof(int), stream); + + CUDA_CHECK(cudaMemset(histogram, 0, params.n_bins * sizeof(int))); + const int TPB = 128; + int numBlocks = raft::ceildiv(params.n_rows, TPB); + generateData<<>>(data, params.n_rows, + params.seed); + DecisionTree::computeQuantiles(quantiles, params.n_bins, data, + params.n_rows, 1, allocator, stream); + + computeHistogram<<>>( + histogram, data, params.n_rows, quantiles, params.n_bins); + + CUDA_CHECK(cudaMemcpyAsync(h_histogram, histogram, + params.n_bins * sizeof(int), + cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + } + + void TearDown() override { + auto allocator = handle->get_device_allocator(); + auto h_allocator = handle->get_host_allocator(); + + allocator->deallocate(data, params.n_rows * sizeof(T), stream); + allocator->deallocate(quantiles, params.n_bins * sizeof(T), stream); + allocator->deallocate(histogram, params.n_bins * sizeof(int), stream); + h_allocator->deallocate(h_histogram, params.n_bins * sizeof(int), stream); + + handle.reset(); + CUDA_CHECK(cudaStreamDestroy(stream)); + } + + void test_histogram() { + int max_items_per_bin = raft::ceildiv(params.n_rows, params.n_bins); + int min_items_per_bin = max_items_per_bin - 1; + int total_items = 0; + for (int b = 0; b < params.n_bins; b++) { + ASSERT_TRUE(h_histogram[b] == max_items_per_bin || + h_histogram[b] == min_items_per_bin) + << "No. samples in bin[" << b << "] = " << h_histogram[b] + << " Expected " << max_items_per_bin << " or " << min_items_per_bin + << std::endl; + total_items += h_histogram[b]; + } + ASSERT_EQ(params.n_rows, total_items) + << "Some samples from dataset are either missed of double counted in " + "quantile bins" + << std::endl; + } + + protected: + std::shared_ptr handle; + cudaStream_t stream; + inputs params; + + T *data, *quantiles; + bool result; + int *histogram, *h_histogram; +}; + +//------------------------------------------------------------------------------------------------------------------------------------- +const std::vector inputs = {{1000, 16, 6078587519764079670LLU}, + {1130, 32, 4884670006177930266LLU}, + {1752, 67, 9175325892580481371LLU}, + {2307, 99, 9507819643927052255LLU}, + {5000, 128, 9507819643927052255LLU}}; + +typedef RFQuantileTest RFQuantileTestF; +TEST_P(RFQuantileTestF, test) { test_histogram(); } + +INSTANTIATE_TEST_CASE_P(RFQuantileTests, RFQuantileTestF, + ::testing::ValuesIn(inputs)); + +typedef RFQuantileTest RFQuantileTestD; +TEST_P(RFQuantileTestD, test) { test_histogram(); } + +INSTANTIATE_TEST_CASE_P(RFQuantileTests, RFQuantileTestD, + ::testing::ValuesIn(inputs)); + +} // end namespace ML