From a3ea64dda849b9d6e59cdc9bc62b2fbf5c87e651 Mon Sep 17 00:00:00 2001
From: Andy Adinets
Date: Sat, 2 Apr 2022 19:52:00 +0200
Subject: [PATCH] float64 support in FIL functions (#4655)

Templatized the functions related to FIL inference in preparation for
`float64` support. Instantiations of templates with `float64`, or tests for
`float64`, _are not included_; they will be provided in a future pull request.

This is pull request 2 of 3 to integrate
https://github.com/rapidsai/cuml/pull/4646. This pull request is partly based
on work by @levsnv.

Authors:
  - Andy Adinets (https://github.com/canonizer)
  - Levs Dolgovs (https://github.com/levsnv)
  - Dante Gama Dessavre (https://github.com/dantegd)

Approvers:
  - Divye Gala (https://github.com/divyegala)
  - William Hicks (https://github.com/wphicks)

URL: https://github.com/rapidsai/cuml/pull/4655
---
 cpp/bench/sg/fil.cu                 |   2 +-
 cpp/include/cuml/fil/fil.h          |  19 ++-
 cpp/src/fil/common.cuh              |  37 +++--
 cpp/src/fil/fil.cu                  | 216 +++++++++++++-------------
 cpp/src/fil/infer.cu                | 225 +++++++++++++++-------------
 cpp/src/fil/internal.cuh            |  30 ++--
 cpp/src/fil/treelite_import.cu      |  22 ++-
 cpp/test/sg/fil_child_index_test.cu |   2 +-
 cpp/test/sg/fil_test.cu             |  24 +--
 cpp/test/sg/rf_test.cu              |   6 +-
 python/cuml/fil/fil.pyx             |  32 ++--
 11 files changed, 335 insertions(+), 280 deletions(-)

diff --git a/cpp/bench/sg/fil.cu b/cpp/bench/sg/fil.cu
index 8128276d3e..adf283fbaf 100644
--- a/cpp/bench/sg/fil.cu
+++ b/cpp/bench/sg/fil.cu
@@ -117,7 +117,7 @@ class FIL : public RegressionFixture {
   }
 
  private:
-  ML::fil::forest_t forest;
+  ML::fil::forest_t<float> forest;
   ModelHandle model;
   Params p_rest;
 };
diff --git a/cpp/include/cuml/fil/fil.h b/cpp/include/cuml/fil/fil.h
index dfa66ad1a8..581fe3eb13 100644
--- a/cpp/include/cuml/fil/fil.h
+++ b/cpp/include/cuml/fil/fil.h
@@ -69,10 +69,12 @@ enum storage_type_t {
 };
 static const char* storage_type_repr[] = {"AUTO", "DENSE", "SPARSE", "SPARSE8"};
 
+template <typename real_t>
 struct forest;
 
 /** forest_t is the predictor handle */
-typedef forest* forest_t;
+template <typename real_t>
+using forest_t = forest<real_t>*;
 
 /** MAX_N_ITEMS determines the maximum allowed value for tl_params::n_items */
 constexpr int MAX_N_ITEMS = 4;
@@ -112,8 +114,9 @@ struct treelite_params_t {
  * @param model treelite model used to initialize the forest
  * @param tl_params additional parameters for the forest
  */
+// TODO (canonizer): use std::variant<forest_t<float>, forest_t<double>>* for pforest
 void from_treelite(const raft::handle_t& handle,
-                   forest_t* pforest,
+                   forest_t<float>* pforest,
                    ModelHandle model,
                    const treelite_params_t* tl_params);
 
@@ -121,24 +124,26 @@ void from_treelite(const raft::handle_t& handle,
  * @param h cuML handle used by this function
  * @param f the forest to free; not usable after the call to this function
  */
-void free(const raft::handle_t& h, forest_t f);
+template <typename real_t>
+void free(const raft::handle_t& h, forest_t<real_t> f);
 
 /** predict predicts on data (with n rows) using forest and writes results into preds;
  *  the number of columns is stored in forest, and both preds and data point to GPU memory
  *  @param h cuML handle used by this function
  *  @param f forest used for predictions
  *  @param preds array in GPU memory to store predictions into
-        size == predict_proba ? (2*num_rows) : num_rows
+ *      size = predict_proba ?
(2*num_rows) : num_rows * @param data array of size n * cols (cols is the number of columns * for the forest f) from which to predict * @param num_rows number of data rows * @param predict_proba for classifier models, this forces to output both class probabilities * instead of binary class prediction. format matches scikit-learn API */ +template void predict(const raft::handle_t& h, - forest_t f, - float* preds, - const float* data, + forest_t f, + real_t* preds, + const real_t* data, size_t num_rows, bool predict_proba = false); diff --git a/cpp/src/fil/common.cuh b/cpp/src/fil/common.cuh index 99c0e7afe8..c4877216a8 100644 --- a/cpp/src/fil/common.cuh +++ b/cpp/src/fil/common.cuh @@ -48,6 +48,7 @@ struct storage_base { /** represents a dense tree */ template struct tree> : tree_base { + using real_type = real_t; __host__ __device__ tree(categorical_sets cat_sets, dense_node* nodes, int node_pitch) : tree_base{cat_sets}, nodes_(nodes), node_pitch_(node_pitch) { @@ -61,10 +62,10 @@ struct tree> : tree_base { }; /** partial specialization of storage. Stores the forest on GPU as a collection of dense nodes */ -template -struct storage> : storage_base { - using real_t = real_t_; - using node_t = dense_node; +template +struct storage> : storage_base { + using real_type = real_t; + using node_t = dense_node; __host__ __device__ storage(categorical_sets cat_sets, real_t* vector_leaf, node_t* nodes, @@ -93,6 +94,7 @@ struct storage> : storage_base { /** sparse tree */ template struct tree : tree_base { + using real_type = typename node_t::real_type; __host__ __device__ tree(categorical_sets cat_sets, node_t* nodes) : tree_base{cat_sets}, nodes_(nodes) { @@ -103,15 +105,15 @@ struct tree : tree_base { /** storage stores the forest on GPU as a collection of sparse nodes */ template -struct storage : storage_base { - using node_t = node_t_; - using real_t = typename node_t::real_t; - int* trees_ = nullptr; - node_t* nodes_ = nullptr; - int num_trees_ = 0; - __host__ __device__ - storage(categorical_sets cat_sets, real_t* vector_leaf, int* trees, node_t* nodes, int num_trees) - : storage_base{cat_sets, vector_leaf}, +struct storage : storage_base { + using node_t = node_t_; + using real_type = typename node_t::real_type; + int* trees_ = nullptr; + node_t* nodes_ = nullptr; + int num_trees_ = 0; + __host__ __device__ storage( + categorical_sets cat_sets, real_type* vector_leaf, int* trees, node_t* nodes, int num_trees) + : storage_base{cat_sets, vector_leaf}, trees_(trees), nodes_(nodes), num_trees_(num_trees) @@ -125,8 +127,11 @@ struct storage : storage_base { } }; -typedef storage> sparse_storage16; -typedef storage sparse_storage8; +using dense_storage_f32 = storage>; +using dense_storage_f64 = storage>; +using sparse_storage16_f32 = storage>; +using sparse_storage16_f64 = storage>; +using sparse_storage8 = storage; /// all model parameters mostly required to compute shared memory footprint, /// also the footprint itself @@ -168,7 +173,7 @@ struct shmem_size_params { { return cols_in_shmem ? sizeof_real * sdata_stride() * n_items << log2_threads_per_tree : 0; } - template + template size_t get_smem_footprint(); }; diff --git a/cpp/src/fil/fil.cu b/cpp/src/fil/fil.cu index 705cd96702..04cd227bac 100644 --- a/cpp/src/fil/fil.cu +++ b/cpp/src/fil/fil.cu @@ -35,35 +35,40 @@ creation and prediction (the main inference kernel is defined in infer.cu). 
*/ namespace ML { namespace fil { -__host__ __device__ float sigmoid(float x) { return 1.0f / (1.0f + expf(-x)); } +template +__host__ __device__ real_t sigmoid(real_t x) +{ + return real_t(1) / (real_t(1) + exp(-x)); +} /** performs additional transformations on the array of forest predictions (preds) of size n; the transformations are defined by output, and include averaging (multiplying by inv_num_trees), adding global_bias (always done), sigmoid and applying threshold. in case of complement_proba, fills in the complement probability */ -__global__ void transform_k(float* preds, +template +__global__ void transform_k(real_t* preds, size_t n, output_t output, - float inv_num_trees, - float threshold, - float global_bias, + real_t inv_num_trees, + real_t threshold, + real_t global_bias, bool complement_proba) { size_t i = threadIdx.x + size_t(blockIdx.x) * blockDim.x; if (i >= n) return; if (complement_proba && i % 2 != 0) return; - float result = preds[i]; + real_t result = preds[i]; if ((output & output_t::AVG) != 0) result *= inv_num_trees; result += global_bias; if ((output & output_t::SIGMOID) != 0) result = sigmoid(result); // will not be done on CATEGORICAL_LEAF because the whole kernel will not run - if ((output & output_t::CLASS) != 0) { result = result > threshold ? 1.0f : 0.0f; } + if ((output & output_t::CLASS) != 0) { result = result > threshold ? real_t(1) : real_t(0); } // sklearn outputs numpy array in 'C' order, with the number of classes being last dimension // that is also the default order, so we should use the same one if (complement_proba) { - preds[i] = 1.0f - result; + preds[i] = real_t(1) - result; preds[i + 1] = result; } else preds[i] = result; @@ -74,8 +79,11 @@ __global__ void transform_k(float* preds, // but rather one symbol for the whole template specialization, as below. extern template int dispatch_on_fil_template_params(compute_smem_footprint, predict_params); +// forest is the base type for all forests and contains data and methods common +// to both dense and sparse forests +template struct forest { - forest(const raft::handle_t& h) : vector_leaf_(0, h.get_stream()), cat_sets_(h.get_stream()) {} + forest(const raft::handle_t& h) : cat_sets_(h.get_stream()), vector_leaf_(0, h.get_stream()) {} void init_shmem_size(int device) { @@ -129,7 +137,6 @@ struct forest { fixed_block_count_ = blocks_per_sm * sm_count; } - template void init_common(const raft::handle_t& h, const categorical_sets& cat_sets, const std::vector& vector_leaf, @@ -139,8 +146,8 @@ struct forest { num_trees_ = params->num_trees; algo_ = params->algo; output_ = params->output; - threshold_ = params->threshold; - global_bias_ = params->global_bias; + threshold_ = static_cast(params->threshold); + global_bias_ = static_cast(params->global_bias); proba_ssp_.n_items = params->n_items; proba_ssp_.log2_threads_per_tree = log2(params->threads_per_tree); proba_ssp_.leaf_algo = params->leaf_algo; @@ -174,7 +181,7 @@ struct forest { virtual void infer(predict_params params, cudaStream_t stream) = 0; void predict( - const raft::handle_t& h, float* preds, const float* data, size_t num_rows, bool predict_proba) + const raft::handle_t& h, real_t* preds, const real_t* data, size_t num_rows, bool predict_proba) { // Initialize prediction parameters. predict_params params(predict_proba ? proba_ssp_ : class_ssp_); @@ -253,7 +260,7 @@ struct forest { // Simulating treelite order, which cancels out bias. // If non-proba prediction used, it still will not matter // for the same reason softmax will not. 
- float global_bias = (ot & output_t::SOFTMAX) != 0 ? 0.0f : global_bias_; + real_t global_bias = (ot & output_t::SOFTMAX) != 0 ? real_t(0) : global_bias_; bool complement_proba = false, do_transform; if (predict_proba) { @@ -270,23 +277,25 @@ struct forest { // for GROVE_PER_CLASS, averaging happens in infer_k ot = output_t(ot & ~output_t::AVG); params.num_outputs = params.num_classes; - do_transform = (ot != output_t::RAW && ot != output_t::SOFTMAX) || global_bias != 0.0f; + do_transform = + (ot != output_t::RAW && ot != output_t::SOFTMAX) || global_bias != real_t(0); break; case leaf_algo_t::CATEGORICAL_LEAF: params.num_outputs = params.num_classes; - do_transform = ot != output_t::RAW || global_bias_ != 0.0f; + do_transform = ot != output_t::RAW || global_bias_ != real_t(0); break; case leaf_algo_t::VECTOR_LEAF: // for VECTOR_LEAF, averaging happens in infer_k ot = output_t(ot & ~output_t::AVG); params.num_outputs = params.num_classes; - do_transform = (ot != output_t::RAW && ot != output_t::SOFTMAX) || global_bias != 0.0f; + do_transform = + (ot != output_t::RAW && ot != output_t::SOFTMAX) || global_bias != real_t(0); break; default: ASSERT(false, "internal error: predict: invalid leaf_algo %d", params.leaf_algo); } } else { if (params.leaf_algo == leaf_algo_t::FLOAT_UNARY_BINARY) { - do_transform = ot != output_t::RAW || global_bias_ != 0.0f; + do_transform = ot != output_t::RAW || global_bias_ != real_t(0); } else { // GROVE_PER_CLASS, CATEGORICAL_LEAF: moot since choosing best class and // all transforms are monotonic. also, would break current code @@ -305,7 +314,7 @@ struct forest { preds, num_values_to_transform, ot, - num_trees_ > 0 ? (1.0f / num_trees_) : 1.0f, + num_trees_ > 0 ? (real_t(1) / num_trees_) : real_t(1), threshold_, global_bias, complement_proba); @@ -321,18 +330,19 @@ struct forest { virtual ~forest() {} - int num_trees_ = 0; - int depth_ = 0; - algo_t algo_ = algo_t::NAIVE; - output_t output_ = output_t::RAW; - float threshold_ = 0.5; - float global_bias_ = 0; - shmem_size_params class_ssp_, proba_ssp_; + int num_trees_ = 0; + int depth_ = 0; + algo_t algo_ = algo_t::NAIVE; + output_t output_ = output_t::RAW; int fixed_block_count_ = 0; int max_shm_ = 0; + real_t threshold_ = 0.5; + real_t global_bias_ = 0; + shmem_size_params class_ssp_; + shmem_size_params proba_ssp_; // vector_leaf_ is only used if {class,proba}_ssp_.leaf_algo is VECTOR_LEAF, // otherwise it is empty - rmm::device_uvector vector_leaf_; + rmm::device_uvector vector_leaf_; cat_sets_device_owner cat_sets_; }; @@ -344,8 +354,6 @@ struct opt_into_arch_dependent_shmem : dispatch_functor { template > void run(predict_params p) { - static_assert(std::is_same_v, - "real_t must be float; to be removed in the following pull requests"); auto kernel = infer_k { }; template -struct dense_forest> : forest { +struct dense_forest> : forest { using node_t = dense_node; - dense_forest(const raft::handle_t& h) : forest(h), nodes_(0, h.get_stream()) {} + dense_forest(const raft::handle_t& h) : forest(h), nodes_(0, h.get_stream()) {} void transform_trees(const node_t* nodes) { @@ -374,14 +382,14 @@ struct dense_forest> : forest { roots of all trees (node 2), and so on. 
*/ int global_node = 0; - for (int tree = 0; tree < num_trees_; ++tree) { + for (int tree = 0; tree < this->num_trees_; ++tree) { int tree_node = 0; // the counters `level` and `branch` are not used for computing node // indices, they are only here to highlight the node ordering within // each tree - for (int level = 0; level <= depth_; ++level) { + for (int level = 0; level <= this->depth_; ++level) { for (int branch = 0; branch < 1 << level; ++branch) { - h_nodes_[tree_node * num_trees_ + tree] = nodes[global_node]; + h_nodes_[tree_node * this->num_trees_ + tree] = nodes[global_node]; ++tree_node; ++global_node; } @@ -398,13 +406,13 @@ struct dense_forest> : forest { const node_t* nodes, const forest_params_t* params) { - init_common(h, cat_sets, vector_leaf, params); - if (algo_ == algo_t::NAIVE) algo_ = algo_t::BATCH_TREE_REORG; + this->init_common(h, cat_sets, vector_leaf, params); + if (this->algo_ == algo_t::NAIVE) this->algo_ = algo_t::BATCH_TREE_REORG; - int num_nodes = forest_num_nodes(num_trees_, depth_); + int num_nodes = forest_num_nodes(this->num_trees_, this->depth_); nodes_.resize(num_nodes, h.get_stream()); h_nodes_.resize(num_nodes); - if (algo_ == algo_t::NAIVE) { + if (this->algo_ == algo_t::NAIVE) { std::copy(nodes, nodes + num_nodes, h_nodes_.begin()); } else { transform_trees(nodes); @@ -416,8 +424,8 @@ struct dense_forest> : forest { h.get_stream())); // predict_proba is a runtime parameter, and opt-in is unconditional - dispatch_on_fil_template_params(opt_into_arch_dependent_shmem>(max_shm_), - static_cast(class_ssp_)); + dispatch_on_fil_template_params(opt_into_arch_dependent_shmem>(this->max_shm_), + static_cast(this->class_ssp_)); // copy must be finished before freeing the host data h.sync_stream(); h_nodes_.clear(); @@ -426,21 +434,19 @@ struct dense_forest> : forest { virtual void infer(predict_params params, cudaStream_t stream) override { - storage forest(cat_sets_.accessor(), - reinterpret_cast(vector_leaf_.data()), + storage forest(this->cat_sets_.accessor(), + this->vector_leaf_.data(), nodes_.data(), - num_trees_, - algo_ == algo_t::NAIVE ? tree_num_nodes(depth_) : 1, - algo_ == algo_t::NAIVE ? 1 : num_trees_); - static_assert(std::is_same_v, - "real_t must be float; to be removed in the following pull requests"); + this->num_trees_, + this->algo_ == algo_t::NAIVE ? tree_num_nodes(this->depth_) : 1, + this->algo_ == algo_t::NAIVE ? 
1 : this->num_trees_); fil::infer(forest, params, stream); } virtual void free(const raft::handle_t& h) override { nodes_.release(); - forest::free(h); + forest::free(h); } rmm::device_uvector nodes_; @@ -448,28 +454,33 @@ struct dense_forest> : forest { }; template -struct sparse_forest : forest { +struct sparse_forest : forest { + using real_type = typename node_t::real_type; + sparse_forest(const raft::handle_t& h) - : forest(h), trees_(0, h.get_stream()), nodes_(0, h.get_stream()) + : forest(h), trees_(0, h.get_stream()), nodes_(0, h.get_stream()) { } void init(const raft::handle_t& h, const categorical_sets& cat_sets, - const std::vector& vector_leaf, + const std::vector& vector_leaf, const int* trees, const node_t* nodes, const forest_params_t* params) { - init_common(h, cat_sets, vector_leaf, params); - if (algo_ == algo_t::ALGO_AUTO) algo_ = algo_t::NAIVE; - depth_ = 0; // a placeholder value - num_nodes_ = params->num_nodes; + this->init_common(h, cat_sets, vector_leaf, params); + if (this->algo_ == algo_t::ALGO_AUTO) this->algo_ = algo_t::NAIVE; + this->depth_ = 0; // a placeholder value + num_nodes_ = params->num_nodes; // trees - trees_.resize(num_trees_, h.get_stream()); - RAFT_CUDA_TRY(cudaMemcpyAsync( - trees_.data(), trees, sizeof(int) * num_trees_, cudaMemcpyHostToDevice, h.get_stream())); + trees_.resize(this->num_trees_, h.get_stream()); + RAFT_CUDA_TRY(cudaMemcpyAsync(trees_.data(), + trees, + sizeof(int) * this->num_trees_, + cudaMemcpyHostToDevice, + h.get_stream())); // nodes nodes_.resize(num_nodes_, h.get_stream()); @@ -477,25 +488,23 @@ struct sparse_forest : forest { nodes_.data(), nodes, sizeof(node_t) * num_nodes_, cudaMemcpyHostToDevice, h.get_stream())); // predict_proba is a runtime parameter, and opt-in is unconditional - dispatch_on_fil_template_params(opt_into_arch_dependent_shmem>(max_shm_), - static_cast(class_ssp_)); + dispatch_on_fil_template_params(opt_into_arch_dependent_shmem>(this->max_shm_), + static_cast(this->class_ssp_)); } virtual void infer(predict_params params, cudaStream_t stream) override { - storage forest(cat_sets_.accessor(), - reinterpret_cast(vector_leaf_.data()), + storage forest(this->cat_sets_.accessor(), + this->vector_leaf_.data(), trees_.data(), nodes_.data(), - num_trees_); - static_assert(std::is_same_v, - "real_t must be float; to be removed in the following pull requests"); + this->num_trees_); fil::infer(forest, params, stream); } void free(const raft::handle_t& h) override { - forest::free(h); + forest::free(h); trees_.release(); nodes_.release(); } @@ -583,11 +592,11 @@ void check_params(const forest_params_t* params, bool dense) /** initializes a forest of any type * When fil_node_t == dense_node, const int* trees is ignored */ -template +template void init(const raft::handle_t& h, - forest_t* pf, + forest_t* pf, const categorical_sets& cat_sets, - const std::vector& vector_leaf, + const std::vector& vector_leaf, const int* trees, const fil_node_t* nodes, const forest_params_t* params) @@ -595,52 +604,59 @@ void init(const raft::handle_t& h, check_params(params, node_traits::IS_DENSE); using forest_type = typename node_traits::forest; forest_type* f = new forest_type(h); - static_assert(std::is_same_v, - "real_t must be float; to be removed in the following pull requests"); f->init(h, cat_sets, vector_leaf, trees, nodes, params); *pf = f; } -// explicit instantiations for init_sparse() -template void init>(const raft::handle_t& h, - forest_t* pf, - const categorical_sets& cat_sets, - const std::vector& vector_leaf, - 
const int* trees, - const sparse_node16* nodes, - const forest_params_t* params); - -template void init(const raft::handle_t& h, - forest_t* pf, - const categorical_sets& cat_sets, - const std::vector& vector_leaf, - const int* trees, - const sparse_node8* nodes, - const forest_params_t* params); - -template void init>(const raft::handle_t& h, - forest_t* pf, - const categorical_sets& cat_sets, - const std::vector& vector_leaf, - const int* trees, - const dense_node* nodes, - const forest_params_t* params); - -void free(const raft::handle_t& h, forest_t f) +// explicit instantiations for init() +template void init, float>(const raft::handle_t& h, + forest_t* pf, + const categorical_sets& cat_sets, + const std::vector& vector_leaf, + const int* trees, + const dense_node* nodes, + const forest_params_t* params); +template void init, float>(const raft::handle_t& h, + forest_t* pf, + const categorical_sets& cat_sets, + const std::vector& vector_leaf, + const int* trees, + const sparse_node16* nodes, + const forest_params_t* params); +template void init(const raft::handle_t& h, + forest_t* pf, + const categorical_sets& cat_sets, + const std::vector& vector_leaf, + const int* trees, + const sparse_node8* nodes, + const forest_params_t* params); + +template +void free(const raft::handle_t& h, forest_t f) { f->free(h); delete f; } +template void free(const raft::handle_t& h, forest_t f); + +template void predict(const raft::handle_t& h, - forest_t f, - float* preds, - const float* data, + forest_t f, + real_t* preds, + const real_t* data, size_t num_rows, bool predict_proba) { f->predict(h, preds, data, num_rows, predict_proba); } +template void predict(const raft::handle_t& h, + forest_t f, + float* preds, + const float* data, + size_t num_rows, + bool predict_proba); + } // namespace fil } // namespace ML diff --git a/cpp/src/fil/infer.cu b/cpp/src/fil/infer.cu index f03543ba11..0bb15aad8c 100644 --- a/cpp/src/fil/infer.cu +++ b/cpp/src/fil/infer.cu @@ -48,7 +48,7 @@ struct vec; template struct Vectorized { BinaryOp op; - __device__ Vectorized(BinaryOp op_) : op(op_) {} + __host__ __device__ Vectorized(BinaryOp op_) : op(op_) {} template constexpr __host__ __device__ __forceinline__ vec operator()(vec a, vec b) const @@ -63,7 +63,7 @@ struct Vectorized { template constexpr __host__ __device__ Vectorized vectorized(BinaryOp op) { - return op; + return Vectorized(op); } template @@ -108,23 +108,23 @@ struct best_margin_label : cub::KeyValuePair { } }; -template -__device__ __forceinline__ vec> to_vec(int c, - vec margin) +template +__device__ __forceinline__ vec> to_vec(int c, + vec margin) { - vec> ret; + vec> ret; CUDA_PRAGMA_UNROLL for (int i = 0; i < NITEMS; ++i) - ret[i] = best_margin_label(c, margin[i]); + ret[i] = best_margin_label(c, margin[i]); return ret; } struct ArgMax { - template - __host__ __device__ __forceinline__ vec> operator()( - vec> a, vec> b) const + template + __host__ __device__ __forceinline__ vec> operator()( + vec> a, vec> b) const { - vec> c; + vec> c; CUDA_PRAGMA_UNROLL for (int i = 0; i < NITEMS; i++) c[i] = cub::ArgMax()(a[i], b[i]); @@ -156,10 +156,8 @@ __device__ __forceinline__ vec tree_leaf_output(tree_type t } template -__device__ __forceinline__ vec infer_one_tree(tree_type tree, - const float* input, - int cols, - int n_rows) +__device__ __forceinline__ vec infer_one_tree( + tree_type tree, const typename tree_type::real_type* input, int cols, int n_rows) { // find the leaf nodes for each row int curr[NITEMS]; @@ -187,10 +185,8 @@ __device__ __forceinline__ 
vec infer_one_tree(tree_type tre } template -__device__ __forceinline__ vec<1, output_type> infer_one_tree(tree_type tree, - const float* input, - int cols, - int rows) +__device__ __forceinline__ vec<1, output_type> infer_one_tree( + tree_type tree, const typename tree_type::real_type* input, int cols, int rows) { int curr = 0; for (;;) { @@ -221,19 +217,19 @@ host code. See https://rapids.ai/start.html as well as cmake defaults. */ // values below are defaults as of this change. -template +template size_t block_reduce_footprint_host() { return sizeof( typename cub:: - BlockReduce, FIL_TPB, cub::BLOCK_REDUCE_WARP_REDUCTIONS, 1, 1, 600>:: + BlockReduce, FIL_TPB, cub::BLOCK_REDUCE_WARP_REDUCTIONS, 1, 1, 600>:: TempStorage); } -template +template size_t block_reduce_best_class_footprint_host() { - return sizeof(typename cub::BlockReduce>, + return sizeof(typename cub::BlockReduce>, FIL_TPB, cub::BLOCK_REDUCE_WARP_REDUCTIONS, 1, @@ -251,9 +247,10 @@ __device__ __forceinline__ T block_reduce(T value, BinaryOp op, void* storage) } template // = FLOAT_UNARY_BINARY struct tree_aggregator_t { - vec acc; + vec acc; void* tmp_storage; /** shared memory footprint of the accumulator during @@ -265,8 +262,8 @@ struct tree_aggregator_t { int log2_threads_per_tree, bool predict_proba) { - return log2_threads_per_tree != 0 ? FIL_TPB * NITEMS * sizeof(float) - : block_reduce_footprint_host(); + return log2_threads_per_tree != 0 ? FIL_TPB * NITEMS * sizeof(real_t) + : block_reduce_footprint_host(); } /** shared memory footprint of the accumulator during @@ -280,19 +277,19 @@ struct tree_aggregator_t { __device__ __forceinline__ tree_aggregator_t(predict_params params, void* accumulate_workspace, void* finalize_workspace, - float* vector_leaf) + real_t* vector_leaf) : tmp_storage(finalize_workspace) { } - __device__ __forceinline__ void accumulate(vec single_tree_prediction, + __device__ __forceinline__ void accumulate(vec single_tree_prediction, int tree, int thread_num_rows) { acc += single_tree_prediction; } - __device__ INLINE_CONFIG void finalize(float* block_out, + __device__ INLINE_CONFIG void finalize(real_t* block_out, int block_num_rows, int output_stride, output_t transform, @@ -305,7 +302,7 @@ struct tree_aggregator_t { if (log2_threads_per_tree == 0) { acc = block_reduce(acc, vectorized(cub::Sum()), tmp_storage); } else { - auto per_thread = (vec*)tmp_storage; + auto per_thread = (vec*)tmp_storage; per_thread[threadIdx.x] = acc; __syncthreads(); // We have two pertinent cases for splitting FIL_TPB == 256 values: @@ -351,13 +348,13 @@ __device__ __forceinline__ auto allreduce_shmem(Iterator begin, // *begin and *end shall be struct vec // tmp_storage may overlap shared memory addressed by [begin, end) -template +template __device__ __forceinline__ void write_best_class( - Iterator begin, Iterator end, void* tmp_storage, float* out, int num_rows) + Iterator begin, Iterator end, void* tmp_storage, real_t* out, int num_rows) { // reduce per-class candidate margins to one best class candidate // per thread (for each of the NITEMS rows) - auto best = vecNITEMS, best_margin_label>(); + auto best = vecNITEMS, best_margin_label>(); for (int c = threadIdx.x; c < end - begin; c += blockDim.x) best = vectorized(cub::ArgMax())(best, to_vec(c, begin[c])); // [begin, end) may overlap tmp_storage @@ -372,7 +369,13 @@ __device__ __forceinline__ void write_best_class( } /// needed for softmax -__device__ float shifted_exp(float margin, float max) { return expf(margin - max); } +struct shifted_exp { + template + 
__device__ double operator()(real_t margin, real_t max) const + { + return exp(margin - max); + } +}; // *begin and *end shall be struct vec // tmp_storage may NOT overlap shared memory addressed by [begin, end) @@ -380,10 +383,10 @@ template __device__ __forceinline__ void block_softmax(Iterator begin, Iterator end, void* tmp_storage) { // subtract max before exponentiating for numerical stability - typedef typename std::iterator_traits::value_type value_type; - value_type max = allreduce_shmem(begin, end, vectorized(cub::Max()), tmp_storage); + using value_type = typename std::iterator_traits::value_type; + value_type max = allreduce_shmem(begin, end, vectorized(cub::Max()), tmp_storage); for (Iterator it = begin + threadIdx.x; it < end; it += blockDim.x) - *it = vectorized(shifted_exp)(*it, max); + *it = vectorized(shifted_exp())(*it, max); // sum of exponents value_type soe = allreduce_shmem(begin, end, vectorized(cub::Sum()), tmp_storage); // softmax phase 2: normalization @@ -393,13 +396,13 @@ __device__ __forceinline__ void block_softmax(Iterator begin, Iterator end, void // *begin and *end shall be struct vec // tmp_storage may NOT overlap shared memory addressed by [begin, end) -template +template __device__ __forceinline__ void normalize_softmax_and_write(Iterator begin, Iterator end, output_t transform, int trees_per_class, void* tmp_storage, - float* out, + real_t* out, int num_rows) { if ((transform & output_t::AVG) != 0) { @@ -418,13 +421,13 @@ __device__ __forceinline__ void normalize_softmax_and_write(Iterator begin, // *begin and *end shall be struct vec // tmp_storage may NOT overlap shared memory addressed by [begin, end) // in case num_outputs > 1 -template +template __device__ __forceinline__ void class_margins_to_global_memory(Iterator begin, Iterator end, output_t transform, int trees_per_class, void* tmp_storage, - float* out, + real_t* out, int num_rows, int num_outputs) { @@ -437,11 +440,11 @@ __device__ __forceinline__ void class_margins_to_global_memory(Iterator begin, } } -template -struct tree_aggregator_t { - vec acc; +template +struct tree_aggregator_t { + vec acc; int num_classes; - vec* per_thread; + vec* per_thread; void* tmp_storage; static size_t smem_finalize_footprint(size_t data_row_size, @@ -449,9 +452,9 @@ struct tree_aggregator_t { int log2_threads_per_tree, bool predict_proba) { - size_t phase1 = (FIL_TPB - FIL_TPB % num_classes) * sizeof(vec); - size_t phase2 = predict_proba ? block_reduce_footprint_host() - : block_reduce_best_class_footprint_host(); + size_t phase1 = (FIL_TPB - FIL_TPB % num_classes) * sizeof(vec); + size_t phase2 = predict_proba ? block_reduce_footprint_host() + : block_reduce_best_class_footprint_host(); return predict_proba ? phase1 + phase2 : std::max(phase1, phase2); } @@ -460,21 +463,21 @@ struct tree_aggregator_t { __device__ __forceinline__ tree_aggregator_t(predict_params params, void* accumulate_workspace, void* finalize_workspace, - float* vector_leaf) + real_t* vector_leaf) : num_classes(params.num_classes), - per_thread((vec*)finalize_workspace), + per_thread((vec*)finalize_workspace), tmp_storage(params.predict_proba ? 
per_thread + num_classes : finalize_workspace) { } - __device__ __forceinline__ void accumulate(vec single_tree_prediction, + __device__ __forceinline__ void accumulate(vec single_tree_prediction, int tree, int thread_num_rows) { acc += single_tree_prediction; } - __device__ INLINE_CONFIG void finalize(float* out, + __device__ INLINE_CONFIG void finalize(real_t* out, int num_rows, int num_outputs, output_t transform, @@ -500,11 +503,11 @@ struct tree_aggregator_t { } }; -template -struct tree_aggregator_t { - vec acc; +template +struct tree_aggregator_t { + vec acc; /// at first, per class margin, then, possibly, different softmax partials - vec* per_class_margin; + vec* per_class_margin; void* tmp_storage; int num_classes; @@ -514,30 +517,30 @@ struct tree_aggregator_t { bool predict_proba) { size_t phase1 = data_row_size + smem_accumulate_footprint(num_classes); - size_t phase2 = predict_proba ? block_reduce_footprint_host() - : block_reduce_best_class_footprint_host(); + size_t phase2 = predict_proba ? block_reduce_footprint_host() + : block_reduce_best_class_footprint_host(); return predict_proba ? phase1 + phase2 : std::max(phase1, phase2); } static __host__ __device__ size_t smem_accumulate_footprint(int num_classes) { - return num_classes * sizeof(vec); + return num_classes * sizeof(vec); } __device__ __forceinline__ tree_aggregator_t(predict_params params, void* accumulate_workspace, void* finalize_workspace, - float* vector_leaf) - : per_class_margin((vec*)accumulate_workspace), + real_t* vector_leaf) + : per_class_margin((vec*)accumulate_workspace), tmp_storage(params.predict_proba ? per_class_margin + num_classes : finalize_workspace), num_classes(params.num_classes) { for (int c = threadIdx.x; c < num_classes; c += blockDim.x) - per_class_margin[c] = vec(0); + per_class_margin[c] = vec(0); // __syncthreads() is called in infer_k } - __device__ __forceinline__ void accumulate(vec single_tree_prediction, + __device__ __forceinline__ void accumulate(vec single_tree_prediction, int tree, int thread_num_rows) { @@ -546,7 +549,7 @@ struct tree_aggregator_t { __syncthreads(); } - __device__ INLINE_CONFIG void finalize(float* out, + __device__ INLINE_CONFIG void finalize(real_t* out, int num_rows, int num_outputs, output_t transform, @@ -564,17 +567,17 @@ struct tree_aggregator_t { } }; -template -struct tree_aggregator_t { +template +struct tree_aggregator_t { // per_class_margin is a row-major matrix // of size num_threads_per_class * num_classes // used to acccumulate class values - vec* per_class_margin; + vec* per_class_margin; vec* vector_leaf_indices; int* thread_num_rows; int num_classes; int num_threads_per_class; - float* vector_leaf; + real_t* vector_leaf; void* tmp_storage; static size_t smem_finalize_footprint(size_t data_row_size, @@ -583,20 +586,20 @@ struct tree_aggregator_t { bool predict_proba) { size_t phase1 = data_row_size + smem_accumulate_footprint(num_classes); - size_t phase2 = predict_proba ? block_reduce_footprint_host() - : block_reduce_best_class_footprint_host(); + size_t phase2 = predict_proba ? block_reduce_footprint_host() + : block_reduce_best_class_footprint_host(); return predict_proba ? 
phase1 + phase2 : std::max(phase1, phase2); } static size_t smem_accumulate_footprint(int num_classes) { - return sizeof(vec) * num_classes * max(1, FIL_TPB / num_classes) + + return sizeof(vec) * num_classes * max(1, FIL_TPB / num_classes) + sizeof(vec) * FIL_TPB + sizeof(int) * FIL_TPB; } __device__ __forceinline__ tree_aggregator_t(predict_params params, void* accumulate_workspace, void* finalize_workspace, - float* vector_leaf) + real_t* vector_leaf) : num_classes(params.num_classes), num_threads_per_class(max(1, blockDim.x / params.num_classes)), vector_leaf(vector_leaf), @@ -604,15 +607,15 @@ struct tree_aggregator_t { { // Assign workspace char* ptr = (char*)accumulate_workspace; - per_class_margin = (vec*)ptr; - ptr += sizeof(vec) * num_classes * num_threads_per_class; + per_class_margin = (vec*)ptr; + ptr += sizeof(vec) * num_classes * num_threads_per_class; vector_leaf_indices = (vec*)ptr; ptr += sizeof(vec) * blockDim.x; thread_num_rows = (int*)ptr; // Initialise shared memory for (int i = threadIdx.x; i < num_classes * num_threads_per_class; i += blockDim.x) { - per_class_margin[i] = vec(); + per_class_margin[i] = vec(); } vector_leaf_indices[threadIdx.x] = vec(); thread_num_rows[threadIdx.x] = 0; @@ -639,13 +642,13 @@ struct tree_aggregator_t { // we have num_classes threads for each j for (int j = i / num_classes; j < blockDim.x; j += num_threads_per_class) { for (int item = 0; item < thread_num_rows[j]; ++item) { - float pred = vector_leaf[vector_leaf_indices[j][item] * num_classes + c]; + real_t pred = vector_leaf[vector_leaf_indices[j][item] * num_classes + c]; per_class_margin[i][item] += pred; } } } } - __device__ INLINE_CONFIG void finalize(float* out, + __device__ INLINE_CONFIG void finalize(real_t* out, int num_rows, int num_outputs, output_t transform, @@ -670,8 +673,8 @@ struct tree_aggregator_t { } }; -template -struct tree_aggregator_t { +template +struct tree_aggregator_t { // could switch to uint16_t to save shared memory // provided raft::myAtomicAdd(short*) simulated with appropriate shifts int* votes; @@ -693,7 +696,7 @@ struct tree_aggregator_t { __device__ __forceinline__ tree_aggregator_t(predict_params params, void* accumulate_workspace, void* finalize_workspace, - float* vector_leaf) + real_t* vector_leaf) : num_classes(params.num_classes), votes((int*)accumulate_workspace) { for (int c = threadIdx.x; c < num_classes; c += FIL_TPB * NITEMS) @@ -714,7 +717,7 @@ struct tree_aggregator_t { } // class probabilities or regression. 
for regression, num_classes // is just the number of outputs for each data instance - __device__ __forceinline__ void finalize_multiple_outputs(float* out, int num_rows) + __device__ __forceinline__ void finalize_multiple_outputs(real_t* out, int num_rows) { __syncthreads(); for (int c = threadIdx.x; c < num_classes; c += blockDim.x) { @@ -725,7 +728,7 @@ struct tree_aggregator_t { } // using this when predicting a single class label, as opposed to sparse class vector // or class probabilities or regression - __device__ __forceinline__ void finalize_class_label(float* out, int num_rows) + __device__ __forceinline__ void finalize_class_label(real_t* out, int num_rows) { __syncthreads(); // make sure all votes[] are final int item = threadIdx.x; @@ -742,7 +745,7 @@ struct tree_aggregator_t { out[row] = best_class; } } - __device__ INLINE_CONFIG void finalize(float* out, + __device__ INLINE_CONFIG void finalize(real_t* out, int num_rows, int num_outputs, output_t transform, @@ -758,8 +761,9 @@ struct tree_aggregator_t { } }; -__device__ INLINE_CONFIG void load_data(float* sdata, - const float* block_input, +template +__device__ INLINE_CONFIG void load_data(real_t* sdata, + const real_t* block_input, predict_params params, int rows_per_block, int block_num_rows) @@ -792,8 +796,9 @@ template __global__ void infer_k(storage_type forest, predict_params params) { + using real_t = typename storage_type::real_type; extern __shared__ char smem[]; - float* sdata = (float*)smem; + real_t* sdata = reinterpret_cast(smem); int sdata_stride = params.sdata_stride(); int rows_per_block = NITEMS << params.log2_threads_per_tree; int num_cols = params.num_cols; @@ -802,11 +807,12 @@ __global__ void infer_k(storage_type forest, predict_params params) block_row0 += rows_per_block * gridDim.x) { int block_num_rows = max(0, (int)min((int64_t)rows_per_block, (int64_t)params.num_rows - block_row0)); - const float* block_input = reinterpret_cast(params.data) + block_row0 * num_cols; + const real_t* block_input = + reinterpret_cast(params.data) + block_row0 * num_cols; if constexpr (cols_in_shmem) load_data(sdata, block_input, params, rows_per_block, block_num_rows); - tree_aggregator_t acc( + tree_aggregator_t acc( params, (char*)sdata + params.cols_shmem_size(), sdata, forest.vector_leaf_); __syncthreads(); // for both row cache init and acc init @@ -822,7 +828,7 @@ __global__ void infer_k(storage_type forest, predict_params params) and is made exact below. Same with thread_num_rows > 0 */ - typedef typename leaf_output_t::T pred_t; + using pred_t = typename leaf_output_t::T; vec prediction; if (tree < forest.num_trees() && thread_num_rows != 0) { prediction = infer_one_tree( @@ -835,7 +841,7 @@ __global__ void infer_k(storage_type forest, predict_params params) // Dummy threads can be marked as having 0 rows acc.accumulate(prediction, tree, tree < forest.num_trees() ? 
thread_num_rows : 0); } - acc.finalize(reinterpret_cast(params.preds) + params.num_outputs * block_row0, + acc.finalize(reinterpret_cast(params.preds) + params.num_outputs * block_row0, block_num_rows, params.num_outputs, params.transform, @@ -845,13 +851,13 @@ __global__ void infer_k(storage_type forest, predict_params params) } } -template +template size_t shmem_size_params::get_smem_footprint() { - size_t finalize_footprint = tree_aggregator_t::smem_finalize_footprint( + size_t finalize_footprint = tree_aggregator_t::smem_finalize_footprint( cols_shmem_size(), num_classes, log2_threads_per_tree, predict_proba); size_t accumulate_footprint = - tree_aggregator_t::smem_accumulate_footprint(num_classes) + + tree_aggregator_t::smem_accumulate_footprint(num_classes) + cols_shmem_size(); return std::max(accumulate_footprint, finalize_footprint); } @@ -859,7 +865,20 @@ size_t shmem_size_params::get_smem_footprint() template int compute_smem_footprint::run(predict_params ssp) { - return ssp.template get_smem_footprint(); + switch (ssp.sizeof_real) { + case 4: + return ssp + .template get_smem_footprint(); + case 8: + return ssp + .template get_smem_footprint(); + default: + ASSERT(false, + "internal error: sizeof_real == %d, but must be 4 or 8", + static_cast(ssp.sizeof_real)); + // unreachable + return 0; + } } // make sure to instantiate all possible get_smem_footprint instantiations @@ -895,15 +914,15 @@ void infer(storage_type forest, predict_params params, cudaStream_t stream) dispatch_on_fil_template_params(infer_k_storage_template(forest, stream), params); } -template void infer>>(storage> forest, - predict_params params, - cudaStream_t stream); -template void infer>>(storage> forest, - predict_params params, - cudaStream_t stream); -template void infer>(storage forest, - predict_params params, - cudaStream_t stream); +template void infer(dense_storage_f32 forest, + predict_params params, + cudaStream_t stream); +template void infer(sparse_storage16_f32 forest, + predict_params params, + cudaStream_t stream); +template void infer(sparse_storage8 forest, + predict_params params, + cudaStream_t stream); } // namespace fil } // namespace ML diff --git a/cpp/src/fil/internal.cuh b/cpp/src/fil/internal.cuh index 5ae613b559..633eefde04 100644 --- a/cpp/src/fil/internal.cuh +++ b/cpp/src/fil/internal.cuh @@ -100,9 +100,9 @@ union val_t { }; /** base_node contains common implementation details for dense and sparse nodes */ -template -struct alignas(2 * sizeof(real_t_)) base_node { - using real_t = real_t_; // floating-point type +template +struct alignas(2 * sizeof(real_t)) base_node { + using real_type = real_t; // floating-point type /** val, for parent nodes, is a threshold or category list offset. For leaf nodes, it is the tree prediction (see see leaf_output_t::T) */ val_t val; @@ -237,12 +237,12 @@ struct sparse_forest; template struct node_traits { - using real_t = typename node_t::real_t; + using real_type = typename node_t::real_type; using storage = ML::fil::storage; using forest = sparse_forest; static const bool IS_DENSE = false; - static const storage_type_t storage_type_enum = - std::is_same, node_t>() ? SPARSE : SPARSE8; + static constexpr storage_type_t storage_type_enum = + std::is_same_v, node_t> ? 
SPARSE : SPARSE8; template static void check(const treelite::ModelImpl& model); }; @@ -341,10 +341,10 @@ struct forest_params_t { output_t output; // threshold is used to for classification if leaf_algo == FLOAT_UNARY_BINARY && (output & // OUTPUT_CLASS) != 0 && !predict_proba, and is ignored otherwise - float threshold; + double threshold; // global_bias is added to the sum of tree predictions // (after averaging, if it is used, but before any further transformations) - float global_bias; + double global_bias; // only used for CATEGORICAL_LEAF inference. since we're storing the // labels in leaves instead of the whole vector, this keeps track // of the number of classes @@ -392,8 +392,8 @@ struct categorical_sets { // set count is due to tree_idx + node_within_tree_idx are both ints, hence uint32_t result template - __host__ __device__ __forceinline__ int category_matches(node_t node, - typename node_t::real_t category) const + __host__ __device__ __forceinline__ int category_matches( + node_t node, typename node_t::real_type category) const { // standard boolean packing. This layout has better ILP // node.set() is global across feature IDs and is an offset (as opposed @@ -409,7 +409,7 @@ struct categorical_sets { FIL will reject a model where an integer within [0, fid_num_cats] cannot be represented precisely as a 32-bit float. */ - using real_t = typename node_t::real_t; + using real_t = typename node_t::real_type; return category < static_cast(fid_num_cats[node.fid()]) && category >= real_t(0) && fetch_bit(bits + node.set(), static_cast(static_cast(category))); } @@ -431,7 +431,7 @@ struct tree_base { template __host__ __device__ __forceinline__ int child_index(const node_t& node, int node_idx, - typename node_t::real_t val) const + typename node_t::real_type val) const { bool cond; @@ -561,11 +561,11 @@ struct cat_sets_device_owner { * @param params pointer to parameters used to initialize the forest * @param vector_leaf optional vector leaves */ -template +template void init(const raft::handle_t& h, - forest_t* pf, + forest_t* pf, const categorical_sets& cat_sets, - const std::vector& vector_leaf, + const std::vector& vector_leaf, const int* trees, const fil_node_t* nodes, const forest_params_t* params); diff --git a/cpp/src/fil/treelite_import.cu b/cpp/src/fil/treelite_import.cu index 56536802f5..68634fe26a 100644 --- a/cpp/src/fil/treelite_import.cu +++ b/cpp/src/fil/treelite_import.cu @@ -631,7 +631,7 @@ struct tl2fil_t { } /// initializes FIL forest object, to be ready to infer - void init_forest(const raft::handle_t& handle, forest_t* pforest) + void init_forest(const raft::handle_t& handle, forest_t* pforest) { ML::fil::init( handle, pforest, cat_sets_.accessor(), vector_leaf_, roots_.data(), nodes_.data(), ¶ms_); @@ -646,7 +646,7 @@ struct tl2fil_t { template void convert(const raft::handle_t& handle, - forest_t* pforest, + forest_t* pforest, const tl::ModelImpl& model, const treelite_params_t& tl_params) { @@ -655,17 +655,23 @@ void convert(const raft::handle_t& handle, tl2fil.init_forest(handle, pforest); } +template +constexpr bool type_supported() +{ + // not using std::is_floating_point because we did not instantiate fp16-based nodes/trees/forests + return std::is_same_v || std::is_same_v; +} + template void from_treelite(const raft::handle_t& handle, - forest_t* pforest, + forest_t* pforest, const tl::ModelImpl& model, const treelite_params_t* tl_params) { // Invariants on threshold and leaf types - static_assert(std::is_same::value || std::is_same::value, + 
static_assert(type_supported(), "Model must contain float32 or float64 thresholds for splits"); - ASSERT((std::is_same::value || std::is_same::value), - "Models with integer leaf output are not yet supported"); + ASSERT(type_supported(), "Models with integer leaf output are not yet supported"); // Display appropriate warnings when float64 values are being casted into // float32, as FIL only supports inferencing with float32 for the time being if (std::is_same::value || std::is_same::value) { @@ -674,6 +680,8 @@ void from_treelite(const raft::handle_t& handle, "doesn't support inferencing models with float64 values. " "This may lead to predictions with reduced accuracy."); } + // same as std::common_type: float+double=double, float+int64_t=float + using real_t = decltype(threshold_t(0) + leaf_t(0)); storage_type_t storage_type = tl_params->storage_type; // build dense trees by default @@ -705,7 +713,7 @@ void from_treelite(const raft::handle_t& handle, } void from_treelite(const raft::handle_t& handle, - forest_t* pforest, + forest_t* pforest, ModelHandle model, const treelite_params_t* tl_params) { diff --git a/cpp/test/sg/fil_child_index_test.cu b/cpp/test/sg/fil_child_index_test.cu index c9b322346e..2ab3eed56e 100644 --- a/cpp/test/sg/fil_child_index_test.cu +++ b/cpp/test/sg/fil_child_index_test.cu @@ -142,7 +142,7 @@ std::ostream& operator<<(std::ostream& os, const ChildIndexTestParams& ps) template class ChildIndexTest : public testing::TestWithParam { - using real_t = typename fil_node_t::real_t; + using real_t = typename fil_node_t::real_type; protected: void check() diff --git a/cpp/test/sg/fil_test.cu b/cpp/test/sg/fil_test.cu index f36bb3a45b..270e0dce66 100644 --- a/cpp/test/sg/fil_test.cu +++ b/cpp/test/sg/fil_test.cu @@ -70,9 +70,9 @@ struct FilTestParams { // Order Of Magnitude for maximum matching category for categorical nodes float max_magnitude_of_matching_cat = 1.0f; // output parameters - output_t output = output_t::RAW; - float threshold = 0.0f; - float global_bias = 0.0f; + output_t output = output_t::RAW; + double threshold = 0.0f; + double global_bias = 0.0f; // runtime parameters int blocks_per_sm = 0; int threads_per_tree = 1; @@ -553,12 +553,12 @@ class BaseFilTest : public testing::TestWithParam { handle.sync_stream(); } - virtual void init_forest(fil::forest_t* pforest) = 0; + virtual void init_forest(fil::forest_t* pforest) = 0; void predict_on_gpu() { - auto stream = handle.get_stream(); - fil::forest_t forest = nullptr; + auto stream = handle.get_stream(); + fil::forest_t forest = nullptr; init_forest(&forest); // predict @@ -678,7 +678,7 @@ class BasePredictFilTest : public BaseFilTest { } } - void init_forest(fil::forest_t* pforest) override + void init_forest(fil::forest_t* pforest) override { constexpr bool IS_DENSE = node_traits::IS_DENSE; std::vector init_nodes; @@ -806,7 +806,7 @@ class TreeliteFilTest : public BaseFilTest { return key; } - void init_forest_impl(fil::forest_t* pforest, fil::storage_type_t storage_type) + void init_forest_impl(fil::forest_t* pforest, fil::storage_type_t storage_type) { auto stream = handle.get_stream(); bool random_forest_flag = (ps.output & fil::output_t::AVG) != 0; @@ -888,7 +888,7 @@ class TreeliteFilTest : public BaseFilTest { class TreeliteDenseFilTest : public TreeliteFilTest { protected: - void init_forest(fil::forest_t* pforest) override + void init_forest(fil::forest_t* pforest) override { init_forest_impl(pforest, fil::storage_type_t::DENSE); } @@ -896,7 +896,7 @@ class TreeliteDenseFilTest : public 
TreeliteFilTest { class TreeliteSparse16FilTest : public TreeliteFilTest { protected: - void init_forest(fil::forest_t* pforest) override + void init_forest(fil::forest_t* pforest) override { init_forest_impl(pforest, fil::storage_type_t::SPARSE); } @@ -904,7 +904,7 @@ class TreeliteSparse16FilTest : public TreeliteFilTest { class TreeliteSparse8FilTest : public TreeliteFilTest { protected: - void init_forest(fil::forest_t* pforest) override + void init_forest(fil::forest_t* pforest) override { init_forest_impl(pforest, fil::storage_type_t::SPARSE8); } @@ -912,7 +912,7 @@ class TreeliteSparse8FilTest : public TreeliteFilTest { class TreeliteAutoFilTest : public TreeliteFilTest { protected: - void init_forest(fil::forest_t* pforest) override + void init_forest(fil::forest_t* pforest) override { init_forest_impl(pforest, fil::storage_type_t::AUTO); } diff --git a/cpp/test/sg/rf_test.cu b/cpp/test/sg/rf_test.cu index d707fe7970..345770efa1 100644 --- a/cpp/test/sg/rf_test.cu +++ b/cpp/test/sg/rf_test.cu @@ -172,7 +172,7 @@ auto FilPredict(const raft::handle_t& handle, 1, 0, nullptr}; - fil::forest_t fil_forest; + fil::forest_t fil_forest; fil::from_treelite(handle, &fil_forest, model, &tl_params); fil::predict(handle, fil_forest, pred->data().get(), X_transpose, params.n_rows, false); return pred; @@ -191,7 +191,7 @@ auto FilPredictProba(const raft::handle_t& handle, build_treelite_forest(&model, forest, params.n_cols); fil::treelite_params_t tl_params{ fil::algo_t::ALGO_AUTO, 0, 0.0f, fil::storage_type_t::AUTO, 8, 1, 0, nullptr}; - fil::forest_t fil_forest; + fil::forest_t fil_forest; fil::from_treelite(handle, &fil_forest, model, &tl_params); fil::predict(handle, fil_forest, pred->data().get(), X_transpose, params.n_rows, true); return pred; @@ -557,7 +557,7 @@ TEST(RfTests, IntegerOverflow) 1, 0, nullptr}; - fil::forest_t fil_forest; + fil::forest_t fil_forest; fil::from_treelite(handle, &fil_forest, model, &tl_params); fil::predict(handle, fil_forest, pred.data().get(), X.data().get(), m, false); } diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx index 4894135791..a9e0b79e1a 100644 --- a/python/cuml/fil/fil.pyx +++ b/python/cuml/fil/fil.pyx @@ -190,10 +190,12 @@ cdef extern from "cuml/fil/fil.h" namespace "ML::fil": SPARSE, SPARSE8 - cdef struct forest: + cdef cppclass forest[real_t]: pass - ctypedef forest* forest_t + # TODO(canonizer): use something like + # ctypedef forest[real_t]* forest_t[real_t] + # once it is supported in Cython cdef struct treelite_params_t: algo_t algo @@ -215,25 +217,25 @@ cdef extern from "cuml/fil/fil.h" namespace "ML::fil": # this affects inference performance and will become configurable soon char** pforest_shape_str - cdef void free(handle_t& handle, - forest_t) + cdef void free[real_t](handle_t& handle, + forest[real_t]*) - cdef void predict(handle_t& handle, - forest_t, - float*, - float*, - size_t, - bool) except + + cdef void predict[real_t](handle_t& handle, + forest[real_t]*, + real_t*, + real_t*, + size_t, + bool) except + - cdef forest_t from_treelite(handle_t& handle, - forest_t*, - ModelHandle, - treelite_params_t*) except + + cdef forest[float]* from_treelite(handle_t& handle, + forest[float]**, + ModelHandle, + treelite_params_t*) except + cdef class ForestInference_impl(): cdef object handle - cdef forest_t forest_data + cdef forest[float]* forest_data cdef size_t num_class cdef bool output_class cdef char* shape_str
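
Below is a minimal usage sketch of the templated C++ API introduced by this patch; it is not part of the diff itself. It assumes that only the `float` instantiations added here are available (the `float64` instantiations arrive in the follow-up pull request), and the header paths, the helper name `run_fil_inference`, and the model/buffer variables are illustrative assumptions rather than code from the repository.

```cpp
// Sketch only: exercises from_treelite(), predict() and free() as declared in
// the updated cuml/fil/fil.h. Assumes `model` was imported via Treelite's C API
// elsewhere and that `d_data`/`d_preds` are pre-allocated device buffers.
#include <cuml/fil/fil.h>

#include <raft/handle.hpp>

#include <treelite/c_api.h>  // for ModelHandle

void run_fil_inference(const raft::handle_t& handle,
                       ModelHandle model,    // treelite model handle (imported elsewhere)
                       const float* d_data,  // device array, num_rows * num_cols
                       float* d_preds,       // device array, num_rows (2*num_rows for proba)
                       size_t num_rows)
{
  // Same parameter layout as the tests touched by this patch (cpp/test/sg/rf_test.cu).
  ML::fil::treelite_params_t tl_params{ML::fil::algo_t::ALGO_AUTO,
                                       /*output_class=*/false,
                                       /*threshold=*/0.5f,
                                       ML::fil::storage_type_t::AUTO,
                                       /*blocks_per_sm=*/8,
                                       /*threads_per_tree=*/1,
                                       /*n_items=*/0,
                                       /*pforest_shape_str=*/nullptr};

  // from_treelite() still takes a forest_t<float>* handle in this pull request;
  // a std::variant-based overload is left as a TODO in fil.h.
  ML::fil::forest_t<float> forest = nullptr;
  ML::fil::from_treelite(handle, &forest, model, &tl_params);

  // predict() and free() are now templates; real_t = float is deduced from the handle.
  ML::fil::predict(handle, forest, d_preds, d_data, num_rows, /*predict_proba=*/false);
  ML::fil::free(handle, forest);
}
```

Once the `double` instantiations land, the same call pattern should work with `forest_t<double>` and `double` device buffers, which is why `predict()` and `free()` are declared as templates over `real_t` rather than as `float`-only overloads.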