From 701956ef7bd90ad5217931360fd524f743375acc Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Thu, 4 Nov 2021 15:44:52 -0700 Subject: [PATCH 01/24] unified tree2fil --- cpp/src/fil/internal.cuh | 9 ++- cpp/src/fil/treelite_import.cu | 115 ++++++--------------------------- 2 files changed, 29 insertions(+), 95 deletions(-) diff --git a/cpp/src/fil/internal.cuh b/cpp/src/fil/internal.cuh index 1d182fc8f5..b8e05e4b51 100644 --- a/cpp/src/fil/internal.cuh +++ b/cpp/src/fil/internal.cuh @@ -149,9 +149,16 @@ __host__ __device__ __forceinline__ val_t base_node::output() const /** dense_node is a single node of a dense forest */ struct alignas(8) dense_node : base_node { dense_node() = default; - dense_node(val_t output, val_t split, int fid, bool def_left, bool is_leaf, bool is_categorical) + dense_node(val_t output, + val_t split, + int fid, + bool def_left, + bool is_leaf, + bool is_categorical, + int left_index = -1) : base_node(output, split, fid, def_left, is_leaf, is_categorical) { + // ignoring left_index, this is useful to unify import from treelite } /** index of the left child, where curr is the index of the current node */ __host__ __device__ int left(int curr) const { return 2 * curr + 1; } diff --git a/cpp/src/fil/treelite_import.cu b/cpp/src/fil/treelite_import.cu index 57a80d8471..87fd29b9f9 100644 --- a/cpp/src/fil/treelite_import.cu +++ b/cpp/src/fil/treelite_import.cu @@ -329,96 +329,23 @@ conversion_state tl2fil_inner_node(int fil_left_child, } else { ASSERT(false, "only numerical and categorical split nodes are supported"); } - fil_node_t node; - if constexpr (std::is_same()) { - node = fil_node_t({}, split, feature_id, default_left, false, is_categorical); - } else { - node = fil_node_t({}, split, feature_id, default_left, false, is_categorical, fil_left_child); - } + fil_node_t node({}, split, feature_id, default_left, false, is_categorical, fil_left_child); return conversion_state{node, tl_left, tl_right}; } -template -void node2fil_dense(std::vector* pnodes, - int root, - int cur, - const tl::Tree& tree, - int node_id, - const forest_params_t& forest_params, - std::vector* vector_leaf, - std::size_t* leaf_counter, - cat_sets_owner* cat_sets, - std::size_t* bit_pool_offset) -{ - if (tree.IsLeaf(node_id)) { - (*pnodes)[root + cur] = dense_node({}, {}, 0, false, true, false); - tl2fil_leaf_payload( - &(*pnodes)[root + cur], root + cur, tree, node_id, forest_params, vector_leaf, leaf_counter); - return; - } - - // inner node - int left = 2 * cur + 1; - conversion_state cs = - tl2fil_inner_node(left, tree, node_id, forest_params, cat_sets, bit_pool_offset); - (*pnodes)[root + cur] = cs.node; - node2fil_dense(pnodes, - root, - left, - tree, - cs.tl_left, - forest_params, - vector_leaf, - leaf_counter, - cat_sets, - bit_pool_offset); - node2fil_dense(pnodes, - root, - left + 1, - tree, - cs.tl_right, - forest_params, - vector_leaf, - leaf_counter, - cat_sets, - bit_pool_offset); -} - -template -void tree2fil_dense(std::vector* pnodes, - int root, - const tl::Tree& tree, - std::size_t tree_idx, - const forest_params_t& forest_params, - std::vector* vector_leaf, - std::size_t* leaf_counter, - cat_sets_owner* cat_sets) -{ - node2fil_dense(pnodes, - root, - 0, - tree, - tree_root(tree), - forest_params, - vector_leaf, - leaf_counter, - cat_sets, - &cat_sets->bit_pool_offsets[tree_idx]); -} - template -int tree2fil_sparse(std::vector& nodes, - int root, - const tl::Tree& tree, - std::size_t tree_idx, - const forest_params_t& forest_params, - std::vector* vector_leaf, - std::size_t* leaf_counter, - cat_sets_owner* cat_sets) +int tree2fil(std::vector& nodes, + int root, + const tl::Tree& tree, + std::size_t tree_idx, + const forest_params_t& forest_params, + std::vector* vector_leaf, + std::size_t* leaf_counter, + cat_sets_owner* cat_sets) { typedef std::pair pair_t; std::stack stack; - int built_index = root + 1; + int sparse_index = root + 1; stack.push(pair_t(tree_root(tree), 0)); while (!stack.empty()) { const pair_t& top = stack.top(); @@ -430,8 +357,8 @@ int tree2fil_sparse(std::vector& nodes, // reserve space for child nodes // left is the offset of the left child node relative to the tree root // in the array of all nodes of the FIL sparse forest - int left = built_index - root; - built_index += 2; + int left = std::is_same() ? 2 * cur + 1 : sparse_index - root; + sparse_index += 2; conversion_state cs = tl2fil_inner_node( left, tree, node_id, forest_params, cat_sets, &cat_sets->bit_pool_offsets[tree_idx]); nodes[root + cur] = cs.node; @@ -631,14 +558,14 @@ void tl2fil_dense(std::vector* pnodes, pnodes->resize(num_nodes, dense_node()); for (std::size_t i = 0; i < model.trees.size(); ++i) { size_t leaf_counter = max_leaves_per_tree * i; - tree2fil_dense(pnodes, - i * tree_num_nodes(params->depth), - model.trees[i], - i, - *params, - vector_leaf, - &leaf_counter, - cat_sets); + tree2fil(*pnodes, + i * tree_num_nodes(params->depth), + model.trees[i], + i, + *params, + vector_leaf, + &leaf_counter, + cat_sets); } } @@ -727,7 +654,7 @@ void tl2fil_sparse(std::vector* ptrees, for (std::size_t i = 0; i < num_trees; ++i) { // Max number of leaves processed so far size_t leaf_counter = ((*ptrees)[i] + i) / 2; - tree2fil_sparse( + tree2fil( *pnodes, (*ptrees)[i], model.trees[i], i, *params, vector_leaf, &leaf_counter, cat_sets); } From 97dc0499ca8ce6b4a47e2bc3ac3c76636bfc09b1 Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Tue, 9 Nov 2021 20:19:04 -0800 Subject: [PATCH 02/24] unified init_dense, init_sparse --- cpp/src/fil/fil.cu | 82 +++++++++++++++++++--------------- cpp/src/fil/internal.cuh | 48 +++++++++----------- cpp/src/fil/treelite_import.cu | 10 ++--- cpp/test/sg/fil_test.cu | 16 +++---- 4 files changed, 80 insertions(+), 76 deletions(-) diff --git a/cpp/src/fil/fil.cu b/cpp/src/fil/fil.cu index 49c854f4e5..1d4feada5f 100644 --- a/cpp/src/fil/fil.cu +++ b/cpp/src/fil/fil.cu @@ -359,9 +359,11 @@ struct dense_forest : forest { } } + /// const int* trees is ignored void init(const raft::handle_t& h, const categorical_sets& cat_sets, const std::vector& vector_leaf, + const int* trees, const dense_node* nodes, const forest_params_t* params) { @@ -532,50 +534,58 @@ void check_params(const forest_params_t* params, bool dense) FIL_TPB); } -void init_dense(const raft::handle_t& h, - forest_t* pf, - const categorical_sets& cat_sets, - const std::vector& vector_leaf, - const dense_node* nodes, - const forest_params_t* params) -{ - check_params(params, true); - dense_forest* f = new dense_forest(h); - f->init(h, cat_sets, vector_leaf, nodes, params); - *pf = f; -} +template +struct forest_from_node_t { + using T = sparse_forest; +}; + +template<> +struct forest_from_node_t { + using T = dense_forest; +}; +/** initializes a forest of any type +* When fil_node_t == dense_node, const int* trees is ignored +*/ template -void init_sparse(const raft::handle_t& h, - forest_t* pf, - const categorical_sets& cat_sets, - const std::vector& vector_leaf, - const int* trees, - const fil_node_t* nodes, - const forest_params_t* params) +void init(const raft::handle_t& h, + forest_t* pf, + const categorical_sets& cat_sets, + const std::vector& vector_leaf, + const int* trees, + const fil_node_t* nodes, + const forest_params_t* params) { - check_params(params, false); - sparse_forest* f = new sparse_forest(h); + check_params(params, is_dense()); + auto f = new typename forest_from_node_t::T(h); f->init(h, cat_sets, vector_leaf, trees, nodes, params); *pf = f; } // explicit instantiations for init_sparse() -template void init_sparse(const raft::handle_t& h, - forest_t* pf, - const categorical_sets& cat_sets, - const std::vector& vector_leaf, - const int* trees, - const sparse_node16* nodes, - const forest_params_t* params); - -template void init_sparse(const raft::handle_t& h, - forest_t* pf, - const categorical_sets& cat_sets, - const std::vector& vector_leaf, - const int* trees, - const sparse_node8* nodes, - const forest_params_t* params); +template void init(const raft::handle_t& h, + forest_t* pf, + const categorical_sets& cat_sets, + const std::vector& vector_leaf, + const int* trees, + const sparse_node16* nodes, + const forest_params_t* params); + +template void init(const raft::handle_t& h, + forest_t* pf, + const categorical_sets& cat_sets, + const std::vector& vector_leaf, + const int* trees, + const sparse_node8* nodes, + const forest_params_t* params); + +template void init(const raft::handle_t& h, + forest_t* pf, + const categorical_sets& cat_sets, + const std::vector& vector_leaf, + const int* trees, + const dense_node* nodes, + const forest_params_t* params); void free(const raft::handle_t& h, forest_t f) { diff --git a/cpp/src/fil/internal.cuh b/cpp/src/fil/internal.cuh index b8e05e4b51..7ea529037b 100644 --- a/cpp/src/fil/internal.cuh +++ b/cpp/src/fil/internal.cuh @@ -215,6 +215,14 @@ struct alignas(8) sparse_node8 : base_node { __host__ __device__ int left(int curr) const { return left_index(); } }; +struct dense_forest; + +template +constexpr bool is_dense() { + return std::is_same() + || std::is_same(); +} + /** leaf_algo_t describes what the leaves in a FIL forest store (predict) and how FIL aggregates them into class margins/regression result/best class **/ @@ -492,40 +500,26 @@ struct cat_sets_device_owner { } }; -/** init_dense uses params and nodes to initialize the dense forest stored in pf - * @param h cuML handle used by this function - * @param pf pointer to where to store the newly created forest - * @param nodes nodes for the forest, of length - (2**(params->depth + 1) - 1) * params->ntrees - * @param params pointer to parameters used to initialize the forest - * @param vector_leaf optional vector leaves - */ -void init_dense(const raft::handle_t& h, - forest_t* pf, - const categorical_sets& cat_sets, - const std::vector& vector_leaf, - const dense_node* nodes, - const forest_params_t* params); - -/** init_sparse uses params, trees and nodes to initialize the sparse forest - * with sparse nodes stored in pf - * @tparam fil_node_t node type to use with the sparse forest; - * must be sparse_node16 or sparse_node8 +/** init uses params, trees and nodes to initialize the forest + * with nodes stored in pf + * @tparam fil_node_t node type to use with the forest; + * must be sparse_node16 or sparse_node8 or dense_node * @param h cuML handle used by this function * @param pf pointer to where to store the newly created forest * @param trees indices of tree roots in the nodes arrray, of length params->ntrees - * @param nodes nodes for the forest, of length params->num_nodes + * @param nodes nodes for the forest, of length params->num_nodes for sparse + or (2**(params->depth + 1) - 1) * params->ntrees for dense forests * @param params pointer to parameters used to initialize the forest * @param vector_leaf optional vector leaves */ template -void init_sparse(const raft::handle_t& h, - forest_t* pf, - const categorical_sets& cat_sets, - const std::vector& vector_leaf, - const int* trees, - const fil_node_t* nodes, - const forest_params_t* params); +void init(const raft::handle_t& h, + forest_t* pf, + const categorical_sets& cat_sets, + const std::vector& vector_leaf, + const int* trees, + const fil_node_t* nodes, + const forest_params_t* params); } // namespace fil diff --git a/cpp/src/fil/treelite_import.cu b/cpp/src/fil/treelite_import.cu index 87fd29b9f9..f2a3c7bda9 100644 --- a/cpp/src/fil/treelite_import.cu +++ b/cpp/src/fil/treelite_import.cu @@ -357,7 +357,7 @@ int tree2fil(std::vector& nodes, // reserve space for child nodes // left is the offset of the left child node relative to the tree root // in the array of all nodes of the FIL sparse forest - int left = std::is_same() ? 2 * cur + 1 : sparse_index - root; + int left = is_dense() ? 2 * cur + 1 : sparse_index - root; sparse_index += 2; conversion_state cs = tl2fil_inner_node( left, tree, node_id, forest_params, cat_sets, &cat_sets->bit_pool_offsets[tree_idx]); @@ -705,8 +705,8 @@ void from_treelite(const raft::handle_t& handle, std::vector nodes; std::vector vector_leaf; tl2fil_dense(&nodes, ¶ms, model, tl_params, &cat_sets, &vector_leaf); - init_dense(handle, pforest, cat_sets.accessor(), vector_leaf, nodes.data(), ¶ms); - // sync is necessary as nodes is used in init_dense(), + init(handle, pforest, cat_sets.accessor(), vector_leaf, nullptr, nodes.data(), ¶ms); + // sync is necessary as nodes is used in init(), // but destructed at the end of this function CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); if (tl_params->pforest_shape_str) { @@ -719,7 +719,7 @@ void from_treelite(const raft::handle_t& handle, std::vector nodes; std::vector vector_leaf; tl2fil_sparse(&trees, &nodes, ¶ms, model, tl_params, &cat_sets, &vector_leaf); - init_sparse( + init( handle, pforest, cat_sets.accessor(), vector_leaf, trees.data(), nodes.data(), ¶ms); CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); if (tl_params->pforest_shape_str) { @@ -732,7 +732,7 @@ void from_treelite(const raft::handle_t& handle, std::vector nodes; std::vector vector_leaf; tl2fil_sparse(&trees, &nodes, ¶ms, model, tl_params, &cat_sets, &vector_leaf); - init_sparse( + init( handle, pforest, cat_sets.accessor(), vector_leaf, trees.data(), nodes.data(), ¶ms); CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); if (tl_params->pforest_shape_str) { diff --git a/cpp/test/sg/fil_test.cu b/cpp/test/sg/fil_test.cu index 293222667e..62652c4286 100644 --- a/cpp/test/sg/fil_test.cu +++ b/cpp/test/sg/fil_test.cu @@ -651,7 +651,7 @@ class PredictDenseFilTest : public BaseFilTest { fil_ps.threads_per_tree = ps.threads_per_tree; fil_ps.n_items = ps.n_items; - fil::init_dense(handle, pforest, cat_sets_h.accessor(), vector_leaf, nodes.data(), &fil_ps); + fil::init(handle, pforest, cat_sets_h.accessor(), vector_leaf, nullptr, nodes.data(), &fil_ps); } }; @@ -719,13 +719,13 @@ class BasePredictSparseFilTest : public BaseFilTest { dense2sparse(); fil_params.num_nodes = sparse_nodes.size(); - fil::init_sparse(handle, - pforest, - cat_sets_h.accessor(), - vector_leaf, - trees.data(), - sparse_nodes.data(), - &fil_params); + fil::init(handle, + pforest, + cat_sets_h.accessor(), + vector_leaf, + trees.data(), + sparse_nodes.data(), + &fil_params); } std::vector sparse_nodes; std::vector trees; From 84bdc1efbe7835c3257067b28802f700f7acb6cf Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Wed, 10 Nov 2021 22:20:06 -0800 Subject: [PATCH 03/24] drafted tl2fil as class for from_treelite --- cpp/src/fil/common.cuh | 10 ++ cpp/src/fil/fil.cu | 15 +-- cpp/src/fil/infer.cu | 2 +- cpp/src/fil/internal.cuh | 6 +- cpp/src/fil/treelite_import.cu | 168 +++++++++++++-------------------- 5 files changed, 91 insertions(+), 110 deletions(-) diff --git a/cpp/src/fil/common.cuh b/cpp/src/fil/common.cuh index 7de2eb8efd..6acd5fdac3 100644 --- a/cpp/src/fil/common.cuh +++ b/cpp/src/fil/common.cuh @@ -114,6 +114,16 @@ struct sparse_storage : storage_base { typedef sparse_storage sparse_storage16; typedef sparse_storage sparse_storage8; +template +struct node2storage { + using T = sparse_storage; +}; + +template <> +struct node2storage { + using T = dense_storage; +}; + /// all model parameters mostly required to compute shared memory footprint, /// also the footprint itself struct shmem_size_params { diff --git a/cpp/src/fil/fil.cu b/cpp/src/fil/fil.cu index 1d4feada5f..90f44e35bd 100644 --- a/cpp/src/fil/fil.cu +++ b/cpp/src/fil/fil.cu @@ -378,6 +378,8 @@ struct dense_forest : forest { } else { transform_trees(nodes); } + printf( + "dense_forest::init nodes: GPU %p CPU %p API %p\n", nodes_.data(), h_nodes_.data(), nodes); CUDA_CHECK(cudaMemcpyAsync(nodes_.data(), h_nodes_.data(), num_nodes * sizeof(dense_node), @@ -391,6 +393,7 @@ struct dense_forest : forest { virtual void infer(predict_params params, cudaStream_t stream) override { + printf("dense_forest::infer nodes: GPU %p\n", nodes_.data()); dense_storage forest(cat_sets_.accessor(), vector_leaf_.data(), nodes_.data(), @@ -535,18 +538,18 @@ void check_params(const forest_params_t* params, bool dense) } template -struct forest_from_node_t { +struct node2forest { using T = sparse_forest; }; -template<> -struct forest_from_node_t { +template <> +struct node2forest { using T = dense_forest; }; /** initializes a forest of any type -* When fil_node_t == dense_node, const int* trees is ignored -*/ + * When fil_node_t == dense_node, const int* trees is ignored + */ template void init(const raft::handle_t& h, forest_t* pf, @@ -557,7 +560,7 @@ void init(const raft::handle_t& h, const forest_params_t* params) { check_params(params, is_dense()); - auto f = new typename forest_from_node_t::T(h); + auto f = new typename node2forest::T(h); f->init(h, cat_sets, vector_leaf, trees, nodes, params); *pf = f; } diff --git a/cpp/src/fil/infer.cu b/cpp/src/fil/infer.cu index 0f709db5ea..0b3673895f 100644 --- a/cpp/src/fil/infer.cu +++ b/cpp/src/fil/infer.cu @@ -36,7 +36,7 @@ #endif // __CUDA_ARCH__ #endif // CUDA_PRAGMA_UNROLL -#define INLINE_CONFIG __forceinline__ +#define INLINE_CONFIG __noinline__ namespace ML { namespace fil { diff --git a/cpp/src/fil/internal.cuh b/cpp/src/fil/internal.cuh index 7ea529037b..d846f4e8d7 100644 --- a/cpp/src/fil/internal.cuh +++ b/cpp/src/fil/internal.cuh @@ -218,9 +218,9 @@ struct alignas(8) sparse_node8 : base_node { struct dense_forest; template -constexpr bool is_dense() { - return std::is_same() - || std::is_same(); +constexpr bool is_dense() +{ + return std::is_same() || std::is_same(); } /** leaf_algo_t describes what the leaves in a FIL forest store (predict) diff --git a/cpp/src/fil/treelite_import.cu b/cpp/src/fil/treelite_import.cu index f2a3c7bda9..d77ed2e8a3 100644 --- a/cpp/src/fil/treelite_import.cu +++ b/cpp/src/fil/treelite_import.cu @@ -536,39 +536,6 @@ void tl2fil_common(forest_params_t* params, params->n_items = tl_params->n_items; } -// uses treelite model with additional tl_params to initialize FIL params -// and dense nodes (stored in *pnodes) -template -void tl2fil_dense(std::vector* pnodes, - forest_params_t* params, - const tl::ModelImpl& model, - const treelite_params_t* tl_params, - cat_sets_owner* cat_sets, - std::vector* vector_leaf) -{ - tl2fil_common(params, model, tl_params); - - // convert the nodes - int num_nodes = forest_num_nodes(params->num_trees, params->depth); - int max_leaves_per_tree = (tree_num_nodes(params->depth) + 1) / 2; - if (params->leaf_algo == VECTOR_LEAF) { - vector_leaf->resize(max_leaves_per_tree * params->num_trees * params->num_classes); - } - *cat_sets = allocate_cat_sets_owner(model); - pnodes->resize(num_nodes, dense_node()); - for (std::size_t i = 0; i < model.trees.size(); ++i) { - size_t leaf_counter = max_leaves_per_tree * i; - tree2fil(*pnodes, - i * tree_num_nodes(params->depth), - model.trees[i], - i, - *params, - vector_leaf, - &leaf_counter, - cat_sets); - } -} - template struct tl2fil_sparse_check_t { template @@ -580,6 +547,15 @@ struct tl2fil_sparse_check_t { } }; +template <> +struct tl2fil_sparse_check_t { + // no extra check for 16-byte sparse nodes + template + static void check(const tl::ModelImpl& model) + { + } +}; + template <> struct tl2fil_sparse_check_t { // no extra check for 16-byte sparse nodes @@ -618,48 +594,66 @@ struct tl2fil_sparse_check_t { } }; -// uses treelite model with additional tl_params to initialize FIL params, -// trees (stored in *ptrees) and sparse nodes (stored in *pnodes) template -void tl2fil_sparse(std::vector* ptrees, - std::vector* pnodes, - forest_params_t* params, - const tl::ModelImpl& model, - const treelite_params_t* tl_params, - cat_sets_owner* cat_sets, - std::vector* vector_leaf) -{ - tl2fil_common(params, model, tl_params); - tl2fil_sparse_check_t::check(model); +struct tl2fil_t { + std::vector trees; + std::vector nodes; + std::vector vector_leaf; + forest_params_t params; + cat_sets_owner cat_sets; + const tl::ModelImpl& model; + const treelite_params_t& tl_params; - size_t num_trees = model.trees.size(); + tl2fil_t(const tl::ModelImpl& model_, const treelite_params_t& tl_params_) + : model(model_), tl_params(tl_params_) + { + tl2fil_common(¶ms, model, &tl_params); + tl2fil_sparse_check_t::check(model); - ptrees->reserve(num_trees); - ptrees->push_back(0); - for (size_t i = 0; i < num_trees - 1; ++i) { - ptrees->push_back(model.trees[i].num_nodes + ptrees->back()); - } - size_t total_nodes = ptrees->back() + model.trees.back().num_nodes; + size_t num_trees = model.trees.size(); - if (params->leaf_algo == VECTOR_LEAF) { - size_t max_leaves = (total_nodes + num_trees) / 2; - vector_leaf->resize(max_leaves * params->num_classes); - } + trees.reserve(num_trees); + trees.push_back(0); + for (size_t i = 0; i < num_trees; ++i) { + int num_nodes = is_dense() ? forest_num_nodes(params.num_trees, params.depth) + : model.trees[i].num_nodes; + trees.push_back(num_nodes + trees.back()); + } + size_t total_nodes = trees.back(); + trees.pop_back(); + + if (params.leaf_algo == VECTOR_LEAF) { + size_t max_leaves = is_dense() + ? num_trees * (tree_num_nodes(params.depth) + 1) / 2 + : (total_nodes + num_trees) / 2; + vector_leaf.resize(max_leaves * params.num_classes); + } - *cat_sets = allocate_cat_sets_owner(model); - pnodes->resize(total_nodes); + cat_sets = allocate_cat_sets_owner(model); + nodes.resize(total_nodes); // convert the nodes #pragma omp parallel for - for (std::size_t i = 0; i < num_trees; ++i) { - // Max number of leaves processed so far - size_t leaf_counter = ((*ptrees)[i] + i) / 2; - tree2fil( - *pnodes, (*ptrees)[i], model.trees[i], i, *params, vector_leaf, &leaf_counter, cat_sets); + for (std::size_t i = 0; i < num_trees; ++i) { + // Max number of leaves processed so far + size_t leaf_counter = (trees[i] + i) / 2; + tree2fil(nodes, trees[i], model.trees[i], i, params, &vector_leaf, &leaf_counter, &cat_sets); + } + + params.num_nodes = nodes.size(); } - params->num_nodes = pnodes->size(); -} + void init_GPU(const raft::handle_t& handle, forest_t* pforest, storage_type_t storage_type) + { + init(handle, pforest, cat_sets.accessor(), vector_leaf, trees.data(), nodes.data(), ¶ms); + // sync is necessary as nodes are used in init(), + // but destructed at the end of this function + CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); + if (tl_params.pforest_shape_str) { + *tl_params.pforest_shape_str = sprintf_shape(model, storage_type, nodes, trees, cat_sets); + } + } +}; template void from_treelite(const raft::handle_t& handle, @@ -698,46 +692,20 @@ void from_treelite(const raft::handle_t& handle, } } - forest_params_t params; - cat_sets_owner cat_sets; switch (storage_type) { case storage_type_t::DENSE: { - std::vector nodes; - std::vector vector_leaf; - tl2fil_dense(&nodes, ¶ms, model, tl_params, &cat_sets, &vector_leaf); - init(handle, pforest, cat_sets.accessor(), vector_leaf, nullptr, nodes.data(), ¶ms); - // sync is necessary as nodes is used in init(), - // but destructed at the end of this function - CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); - if (tl_params->pforest_shape_str) { - *tl_params->pforest_shape_str = sprintf_shape(model, storage_type, nodes, {}, cat_sets); - } + tl2fil_t tl2fil(model, *tl_params); + tl2fil.init_GPU(handle, pforest, storage_type); break; } case storage_type_t::SPARSE: { - std::vector trees; - std::vector nodes; - std::vector vector_leaf; - tl2fil_sparse(&trees, &nodes, ¶ms, model, tl_params, &cat_sets, &vector_leaf); - init( - handle, pforest, cat_sets.accessor(), vector_leaf, trees.data(), nodes.data(), ¶ms); - CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); - if (tl_params->pforest_shape_str) { - *tl_params->pforest_shape_str = sprintf_shape(model, storage_type, nodes, trees, cat_sets); - } + tl2fil_t tl2fil(model, *tl_params); + tl2fil.init_GPU(handle, pforest, storage_type); break; } case storage_type_t::SPARSE8: { - std::vector trees; - std::vector nodes; - std::vector vector_leaf; - tl2fil_sparse(&trees, &nodes, ¶ms, model, tl_params, &cat_sets, &vector_leaf); - init( - handle, pforest, cat_sets.accessor(), vector_leaf, trees.data(), nodes.data(), ¶ms); - CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); - if (tl_params->pforest_shape_str) { - *tl_params->pforest_shape_str = sprintf_shape(model, storage_type, nodes, trees, cat_sets); - } + tl2fil_t tl2fil(model, *tl_params); + tl2fil.init_GPU(handle, pforest, storage_type); break; } default: ASSERT(false, "tl_params->sparse must be one of AUTO, DENSE or SPARSE"); @@ -759,7 +727,7 @@ void from_treelite(const raft::handle_t& handle, // allocates caller-owned char* using malloc() template char* sprintf_shape(const tl::ModelImpl& model, - storage_type_t storage, + storage_type_t storage_type, const std::vector& nodes, const std::vector& trees, const cat_sets_owner cat_sets) @@ -768,8 +736,8 @@ char* sprintf_shape(const tl::ModelImpl& model, double size_mb = (trees.size() * sizeof(trees.front()) + nodes.size() * sizeof(nodes.front()) + cat_sets.bits.size()) / 1e6; - forest_shape << storage_type_repr[storage] << " model size " << std::setprecision(2) << size_mb - << " MB" << std::endl; + forest_shape << storage_type_repr[storage_type] << " model size " << std::setprecision(2) + << size_mb << " MB" << std::endl; if (cat_sets.bits.size() > 0) { forest_shape << "number of categorical nodes for each feature id: {"; std::size_t total_cat_nodes = 0; From c1b5d540a312a9c1cf0837f0b663a532c74c1f16 Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Thu, 11 Nov 2021 00:18:20 -0800 Subject: [PATCH 04/24] fixed a bug --- cpp/src/fil/fil.cu | 3 --- cpp/src/fil/infer.cu | 2 +- cpp/src/fil/treelite_import.cu | 6 +++--- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/cpp/src/fil/fil.cu b/cpp/src/fil/fil.cu index 90f44e35bd..b0f6eebc0d 100644 --- a/cpp/src/fil/fil.cu +++ b/cpp/src/fil/fil.cu @@ -378,8 +378,6 @@ struct dense_forest : forest { } else { transform_trees(nodes); } - printf( - "dense_forest::init nodes: GPU %p CPU %p API %p\n", nodes_.data(), h_nodes_.data(), nodes); CUDA_CHECK(cudaMemcpyAsync(nodes_.data(), h_nodes_.data(), num_nodes * sizeof(dense_node), @@ -393,7 +391,6 @@ struct dense_forest : forest { virtual void infer(predict_params params, cudaStream_t stream) override { - printf("dense_forest::infer nodes: GPU %p\n", nodes_.data()); dense_storage forest(cat_sets_.accessor(), vector_leaf_.data(), nodes_.data(), diff --git a/cpp/src/fil/infer.cu b/cpp/src/fil/infer.cu index 0b3673895f..0f709db5ea 100644 --- a/cpp/src/fil/infer.cu +++ b/cpp/src/fil/infer.cu @@ -36,7 +36,7 @@ #endif // __CUDA_ARCH__ #endif // CUDA_PRAGMA_UNROLL -#define INLINE_CONFIG __noinline__ +#define INLINE_CONFIG __forceinline__ namespace ML { namespace fil { diff --git a/cpp/src/fil/treelite_import.cu b/cpp/src/fil/treelite_import.cu index d77ed2e8a3..57dd227346 100644 --- a/cpp/src/fil/treelite_import.cu +++ b/cpp/src/fil/treelite_import.cu @@ -612,11 +612,11 @@ struct tl2fil_t { size_t num_trees = model.trees.size(); - trees.reserve(num_trees); + trees.reserve(num_trees + 1); trees.push_back(0); for (size_t i = 0; i < num_trees; ++i) { - int num_nodes = is_dense() ? forest_num_nodes(params.num_trees, params.depth) - : model.trees[i].num_nodes; + int num_nodes = + is_dense() ? tree_num_nodes(params.depth) : model.trees[i].num_nodes; trees.push_back(num_nodes + trees.back()); } size_t total_nodes = trees.back(); From 6e8e463df66cd92b47844f8b8f1e1bf5dd56c0ad Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Thu, 11 Nov 2021 00:24:43 -0800 Subject: [PATCH 05/24] stray changes --- cpp/src/fil/treelite_import.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/src/fil/treelite_import.cu b/cpp/src/fil/treelite_import.cu index 57dd227346..bcfbd51aad 100644 --- a/cpp/src/fil/treelite_import.cu +++ b/cpp/src/fil/treelite_import.cu @@ -727,7 +727,7 @@ void from_treelite(const raft::handle_t& handle, // allocates caller-owned char* using malloc() template char* sprintf_shape(const tl::ModelImpl& model, - storage_type_t storage_type, + storage_type_t storage, const std::vector& nodes, const std::vector& trees, const cat_sets_owner cat_sets) @@ -736,8 +736,8 @@ char* sprintf_shape(const tl::ModelImpl& model, double size_mb = (trees.size() * sizeof(trees.front()) + nodes.size() * sizeof(nodes.front()) + cat_sets.bits.size()) / 1e6; - forest_shape << storage_type_repr[storage_type] << " model size " << std::setprecision(2) - << size_mb << " MB" << std::endl; + forest_shape << storage_type_repr[storage] << " model size " << std::setprecision(2) << size_mb + << " MB" << std::endl; if (cat_sets.bits.size() > 0) { forest_shape << "number of categorical nodes for each feature id: {"; std::size_t total_cat_nodes = 0; From d6d0ece851eab0f1059124c52c2a2f67e1f19a5f Mon Sep 17 00:00:00 2001 From: Levs Dolgovs <36520083+levsnv@users.noreply.github.com> Date: Thu, 11 Nov 2021 19:06:08 -0800 Subject: [PATCH 06/24] Apply suggestions from code review Co-authored-by: Andy Adinets --- cpp/src/fil/fil.cu | 2 +- cpp/src/fil/internal.cuh | 4 ++-- cpp/src/fil/treelite_import.cu | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/src/fil/fil.cu b/cpp/src/fil/fil.cu index b0f6eebc0d..a02fffbe20 100644 --- a/cpp/src/fil/fil.cu +++ b/cpp/src/fil/fil.cu @@ -359,7 +359,7 @@ struct dense_forest : forest { } } - /// const int* trees is ignored + /// const int* trees is ignored and only provided for compatibility with sparse_forest::init() void init(const raft::handle_t& h, const categorical_sets& cat_sets, const std::vector& vector_leaf, diff --git a/cpp/src/fil/internal.cuh b/cpp/src/fil/internal.cuh index d846f4e8d7..0778354ebb 100644 --- a/cpp/src/fil/internal.cuh +++ b/cpp/src/fil/internal.cuh @@ -503,10 +503,10 @@ struct cat_sets_device_owner { /** init uses params, trees and nodes to initialize the forest * with nodes stored in pf * @tparam fil_node_t node type to use with the forest; - * must be sparse_node16 or sparse_node8 or dense_node + * must be sparse_node16, sparse_node8 or dense_node * @param h cuML handle used by this function * @param pf pointer to where to store the newly created forest - * @param trees indices of tree roots in the nodes arrray, of length params->ntrees + * @param trees for sparse forests, indices of tree roots in the nodes arrray, of length params->ntrees; ignored for dense forests * @param nodes nodes for the forest, of length params->num_nodes for sparse or (2**(params->depth + 1) - 1) * params->ntrees for dense forests * @param params pointer to parameters used to initialize the forest diff --git a/cpp/src/fil/treelite_import.cu b/cpp/src/fil/treelite_import.cu index bcfbd51aad..1bf614b1e4 100644 --- a/cpp/src/fil/treelite_import.cu +++ b/cpp/src/fil/treelite_import.cu @@ -329,7 +329,7 @@ conversion_state tl2fil_inner_node(int fil_left_child, } else { ASSERT(false, "only numerical and categorical split nodes are supported"); } - fil_node_t node({}, split, feature_id, default_left, false, is_categorical, fil_left_child); + fil_node_t node(val_t{}, split, feature_id, default_left, false, is_categorical, fil_left_child); return conversion_state{node, tl_left, tl_right}; } @@ -549,7 +549,7 @@ struct tl2fil_sparse_check_t { template <> struct tl2fil_sparse_check_t { - // no extra check for 16-byte sparse nodes + // no extra check for dense nodes template static void check(const tl::ModelImpl& model) { From 42c1c2ea659af475f0e38b866d9522b5f6162336 Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Mon, 15 Nov 2021 17:23:22 -0800 Subject: [PATCH 07/24] apply suggestions from code review --- cpp/src/fil/common.cuh | 27 ++++++++-- cpp/src/fil/treelite_import.cu | 95 ++++++++++++---------------------- 2 files changed, 54 insertions(+), 68 deletions(-) diff --git a/cpp/src/fil/common.cuh b/cpp/src/fil/common.cuh index 6acd5fdac3..8b693a39ea 100644 --- a/cpp/src/fil/common.cuh +++ b/cpp/src/fil/common.cuh @@ -28,6 +28,11 @@ #include "internal.cuh" +namespace treelite { + template + struct ModelImpl; +} + namespace ML { namespace fil { @@ -114,14 +119,26 @@ struct sparse_storage : storage_base { typedef sparse_storage sparse_storage16; typedef sparse_storage sparse_storage8; -template -struct node2storage { - using T = sparse_storage; +struct dense_forest; +template +struct sparse_forest; + +template +struct node_traits { + using storage = sparse_storage; + using forest = sparse_forest; + static const bool IS_DENSE = false; + template + static void check(const treelite::ModelImpl& model); }; template <> -struct node2storage { - using T = dense_storage; +struct node_traits { + using storage = dense_storage; + using forest = dense_forest; + static const bool IS_DENSE = true; + template + static void check(const treelite::ModelImpl& model) {} }; /// all model parameters mostly required to compute shared memory footprint, diff --git a/cpp/src/fil/treelite_import.cu b/cpp/src/fil/treelite_import.cu index bcfbd51aad..c4a094ea4b 100644 --- a/cpp/src/fil/treelite_import.cu +++ b/cpp/src/fil/treelite_import.cu @@ -357,7 +357,7 @@ int tree2fil(std::vector& nodes, // reserve space for child nodes // left is the offset of the left child node relative to the tree root // in the array of all nodes of the FIL sparse forest - int left = is_dense() ? 2 * cur + 1 : sparse_index - root; + int left = node_traits::IS_DENSE ? 2 * cur + 1 : sparse_index - root; sparse_index += 2; conversion_state cs = tl2fil_inner_node( left, tree, node_id, forest_params, cat_sets, &cat_sets->bit_pool_offsets[tree_idx]); @@ -536,63 +536,32 @@ void tl2fil_common(forest_params_t* params, params->n_items = tl_params->n_items; } -template -struct tl2fil_sparse_check_t { - template - static void check(const tl::ModelImpl& model) - { - ASSERT(false, - "internal error: " - "only a specialization of this template should be used"); - } -}; - -template <> -struct tl2fil_sparse_check_t { - // no extra check for 16-byte sparse nodes - template - static void check(const tl::ModelImpl& model) - { - } -}; - -template <> -struct tl2fil_sparse_check_t { - // no extra check for 16-byte sparse nodes - template - static void check(const tl::ModelImpl& model) - { - } -}; - -template <> -struct tl2fil_sparse_check_t { - static const int MAX_FEATURES = 1 << sparse_node8::FID_NUM_BITS; - static const int MAX_TREE_NODES = (1 << sparse_node8::LEFT_NUM_BITS) - 1; - template - static void check(const tl::ModelImpl& model) - { - // check the number of features - int num_features = model.num_feature; - ASSERT(num_features <= MAX_FEATURES, - "model has %d features, " +template +template +void node_traits::check(const treelite::ModelImpl& model) { + if constexpr(!std::is_same()) return; + const int MAX_FEATURES = 1 << sparse_node8::FID_NUM_BITS; + const int MAX_TREE_NODES = (1 << sparse_node8::LEFT_NUM_BITS) - 1; + // check the number of features + int num_features = model.num_feature; + ASSERT(num_features <= MAX_FEATURES, + "model has %d features, " + "but only %d supported for 8-byte sparse nodes", + num_features, + MAX_FEATURES); + + // check the number of tree nodes + const std::vector>& trees = model.trees; + for (std::size_t i = 0; i < trees.size(); ++i) { + int num_nodes = trees[i].num_nodes; + ASSERT(num_nodes <= MAX_TREE_NODES, + "tree %zu has %d nodes, " "but only %d supported for 8-byte sparse nodes", - num_features, - MAX_FEATURES); - - // check the number of tree nodes - const std::vector>& trees = model.trees; - for (std::size_t i = 0; i < trees.size(); ++i) { - int num_nodes = trees[i].num_nodes; - ASSERT(num_nodes <= MAX_TREE_NODES, - "tree %zu has %d nodes, " - "but only %d supported for 8-byte sparse nodes", - i, - num_nodes, - MAX_TREE_NODES); - } + i, + num_nodes, + MAX_TREE_NODES); } -}; +} template struct tl2fil_t { @@ -608,7 +577,7 @@ struct tl2fil_t { : model(model_), tl_params(tl_params_) { tl2fil_common(¶ms, model, &tl_params); - tl2fil_sparse_check_t::check(model); + node_traits::check(model); size_t num_trees = model.trees.size(); @@ -616,14 +585,14 @@ struct tl2fil_t { trees.push_back(0); for (size_t i = 0; i < num_trees; ++i) { int num_nodes = - is_dense() ? tree_num_nodes(params.depth) : model.trees[i].num_nodes; + node_traits::IS_DENSE ? tree_num_nodes(params.depth) : model.trees[i].num_nodes; trees.push_back(num_nodes + trees.back()); } size_t total_nodes = trees.back(); trees.pop_back(); if (params.leaf_algo == VECTOR_LEAF) { - size_t max_leaves = is_dense() + size_t max_leaves = node_traits::IS_DENSE ? num_trees * (tree_num_nodes(params.depth) + 1) / 2 : (total_nodes + num_trees) / 2; vector_leaf.resize(max_leaves * params.num_classes); @@ -643,7 +612,7 @@ struct tl2fil_t { params.num_nodes = nodes.size(); } - void init_GPU(const raft::handle_t& handle, forest_t* pforest, storage_type_t storage_type) + void init_gpu(const raft::handle_t& handle, forest_t* pforest, storage_type_t storage_type) { init(handle, pforest, cat_sets.accessor(), vector_leaf, trees.data(), nodes.data(), ¶ms); // sync is necessary as nodes are used in init(), @@ -695,17 +664,17 @@ void from_treelite(const raft::handle_t& handle, switch (storage_type) { case storage_type_t::DENSE: { tl2fil_t tl2fil(model, *tl_params); - tl2fil.init_GPU(handle, pforest, storage_type); + tl2fil.init_gpu(handle, pforest, storage_type); break; } case storage_type_t::SPARSE: { tl2fil_t tl2fil(model, *tl_params); - tl2fil.init_GPU(handle, pforest, storage_type); + tl2fil.init_gpu(handle, pforest, storage_type); break; } case storage_type_t::SPARSE8: { tl2fil_t tl2fil(model, *tl_params); - tl2fil.init_GPU(handle, pforest, storage_type); + tl2fil.init_gpu(handle, pforest, storage_type); break; } default: ASSERT(false, "tl_params->sparse must be one of AUTO, DENSE or SPARSE"); From 6967714205114b2544975deb422decab6f7248ce Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Tue, 16 Nov 2021 18:17:30 -0800 Subject: [PATCH 08/24] style --- cpp/src/fil/common.cuh | 16 +++++++++------- cpp/src/fil/fil.cu | 3 ++- cpp/src/fil/internal.cuh | 3 ++- cpp/src/fil/treelite_import.cu | 5 +++-- 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/cpp/src/fil/common.cuh b/cpp/src/fil/common.cuh index 34257740a0..5b187567fc 100644 --- a/cpp/src/fil/common.cuh +++ b/cpp/src/fil/common.cuh @@ -29,8 +29,8 @@ #include "internal.cuh" namespace treelite { - template - struct ModelImpl; +template +struct ModelImpl; } namespace ML { @@ -125,8 +125,8 @@ struct sparse_forest; template struct node_traits { - using storage = sparse_storage; - using forest = sparse_forest; + using storage = sparse_storage; + using forest = sparse_forest; static const bool IS_DENSE = false; template static void check(const treelite::ModelImpl& model); @@ -134,11 +134,13 @@ struct node_traits { template <> struct node_traits { - using storage = dense_storage; - using forest = dense_forest; + using storage = dense_storage; + using forest = dense_forest; static const bool IS_DENSE = true; template - static void check(const treelite::ModelImpl& model) {} + static void check(const treelite::ModelImpl& model) + { + } }; /// all model parameters mostly required to compute shared memory footprint, diff --git a/cpp/src/fil/fil.cu b/cpp/src/fil/fil.cu index 4df7e86a09..f848ffbcad 100644 --- a/cpp/src/fil/fil.cu +++ b/cpp/src/fil/fil.cu @@ -379,7 +379,8 @@ struct dense_forest : forest { } } - /// const int* trees is ignored and only provided for compatibility with sparse_forest::init() + /// const int* trees is ignored and only provided for compatibility with + /// sparse_forest::init() void init(const raft::handle_t& h, const categorical_sets& cat_sets, const std::vector& vector_leaf, diff --git a/cpp/src/fil/internal.cuh b/cpp/src/fil/internal.cuh index 5bd52eeb71..0f3c3b3a0e 100644 --- a/cpp/src/fil/internal.cuh +++ b/cpp/src/fil/internal.cuh @@ -513,7 +513,8 @@ struct cat_sets_device_owner { * must be sparse_node16, sparse_node8 or dense_node * @param h cuML handle used by this function * @param pf pointer to where to store the newly created forest - * @param trees for sparse forests, indices of tree roots in the nodes arrray, of length params->ntrees; ignored for dense forests + * @param trees for sparse forests, indices of tree roots in the nodes arrray, of length + params->ntrees; ignored for dense forests * @param nodes nodes for the forest, of length params->num_nodes for sparse or (2**(params->depth + 1) - 1) * params->ntrees for dense forests * @param params pointer to parameters used to initialize the forest diff --git a/cpp/src/fil/treelite_import.cu b/cpp/src/fil/treelite_import.cu index 5371bd821e..f242d98a06 100644 --- a/cpp/src/fil/treelite_import.cu +++ b/cpp/src/fil/treelite_import.cu @@ -545,8 +545,9 @@ void tl2fil_common(forest_params_t* params, template template -void node_traits::check(const treelite::ModelImpl& model) { - if constexpr(!std::is_same()) return; +void node_traits::check(const treelite::ModelImpl& model) +{ + if constexpr (!std::is_same()) return; const int MAX_FEATURES = 1 << sparse_node8::FID_NUM_BITS; const int MAX_TREE_NODES = (1 << sparse_node8::LEFT_NUM_BITS) - 1; // check the number of features From dfdfff1737d32949e46a89f8d1bc91d7740c2478 Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Tue, 16 Nov 2021 20:14:22 -0800 Subject: [PATCH 09/24] made tree2fil a method of tl2fil_t, misc comments --- cpp/src/fil/fil.cu | 15 +---- cpp/src/fil/internal.cuh | 10 +-- cpp/src/fil/treelite_import.cu | 108 ++++++++++++++++----------------- 3 files changed, 55 insertions(+), 78 deletions(-) diff --git a/cpp/src/fil/fil.cu b/cpp/src/fil/fil.cu index f848ffbcad..92d230d0da 100644 --- a/cpp/src/fil/fil.cu +++ b/cpp/src/fil/fil.cu @@ -563,16 +563,6 @@ void check_params(const forest_params_t* params, bool dense) FIL_TPB); } -template -struct node2forest { - using T = sparse_forest; -}; - -template <> -struct node2forest { - using T = dense_forest; -}; - /** initializes a forest of any type * When fil_node_t == dense_node, const int* trees is ignored */ @@ -585,8 +575,9 @@ void init(const raft::handle_t& h, const fil_node_t* nodes, const forest_params_t* params) { - check_params(params, is_dense()); - auto f = new typename node2forest::T(h); + check_params(params, node_traits::IS_DENSE); + using forest_type = typename node_traits::forest; + forest_type* f = new forest_type(h); f->init(h, cat_sets, vector_leaf, trees, nodes, params); *pf = f; } diff --git a/cpp/src/fil/internal.cuh b/cpp/src/fil/internal.cuh index 0f3c3b3a0e..885a9772d2 100644 --- a/cpp/src/fil/internal.cuh +++ b/cpp/src/fil/internal.cuh @@ -149,6 +149,7 @@ __host__ __device__ __forceinline__ val_t base_node::output() const /** dense_node is a single node of a dense forest */ struct alignas(8) dense_node : base_node { dense_node() = default; + /// ignoring left_index, this is useful to unify import from treelite dense_node(val_t output, val_t split, int fid, @@ -158,7 +159,6 @@ struct alignas(8) dense_node : base_node { int left_index = -1) : base_node(output, split, fid, def_left, is_leaf, is_categorical) { - // ignoring left_index, this is useful to unify import from treelite } /** index of the left child, where curr is the index of the current node */ __host__ __device__ int left(int curr) const { return 2 * curr + 1; } @@ -215,14 +215,6 @@ struct alignas(8) sparse_node8 : base_node { __host__ __device__ int left(int curr) const { return left_index(); } }; -struct dense_forest; - -template -constexpr bool is_dense() -{ - return std::is_same() || std::is_same(); -} - /** leaf_algo_t describes what the leaves in a FIL forest store (predict) and how FIL aggregates them into class margins/regression result/best class **/ diff --git a/cpp/src/fil/treelite_import.cu b/cpp/src/fil/treelite_import.cu index f242d98a06..d1a960c245 100644 --- a/cpp/src/fil/treelite_import.cu +++ b/cpp/src/fil/treelite_import.cu @@ -301,7 +301,6 @@ template conversion_state tl2fil_inner_node(int fil_left_child, const tl::Tree& tree, int tl_node_id, - const forest_params_t& forest_params, cat_sets_owner* cat_sets, std::size_t* bit_pool_offset) { @@ -340,51 +339,6 @@ conversion_state tl2fil_inner_node(int fil_left_child, return conversion_state{node, tl_left, tl_right}; } -template -int tree2fil(std::vector& nodes, - int root, - const tl::Tree& tree, - std::size_t tree_idx, - const forest_params_t& forest_params, - std::vector* vector_leaf, - std::size_t* leaf_counter, - cat_sets_owner* cat_sets) -{ - typedef std::pair pair_t; - std::stack stack; - int sparse_index = root + 1; - stack.push(pair_t(tree_root(tree), 0)); - while (!stack.empty()) { - const pair_t& top = stack.top(); - int node_id = top.first; - int cur = top.second; - stack.pop(); - - while (!tree.IsLeaf(node_id)) { - // reserve space for child nodes - // left is the offset of the left child node relative to the tree root - // in the array of all nodes of the FIL sparse forest - int left = node_traits::IS_DENSE ? 2 * cur + 1 : sparse_index - root; - sparse_index += 2; - conversion_state cs = tl2fil_inner_node( - left, tree, node_id, forest_params, cat_sets, &cat_sets->bit_pool_offsets[tree_idx]); - nodes[root + cur] = cs.node; - // push child nodes into the stack - stack.push(pair_t(cs.tl_right, left + 1)); - // stack.push(pair_t(tl_left, left)); - node_id = cs.tl_left; - cur = left; - } - - // leaf node - nodes[root + cur] = fil_node_t({}, {}, 0, false, true, false, 0); - tl2fil_leaf_payload( - &nodes[root + cur], root + cur, tree, node_id, forest_params, vector_leaf, leaf_counter); - } - - return root; -} - struct level_entry { int n_branch_nodes, n_leaves; }; @@ -573,7 +527,7 @@ void node_traits::check(const treelite::ModelImpl& template struct tl2fil_t { - std::vector trees; + std::vector roots; std::vector nodes; std::vector vector_leaf; forest_params_t params; @@ -589,15 +543,15 @@ struct tl2fil_t { size_t num_trees = model.trees.size(); - trees.reserve(num_trees + 1); - trees.push_back(0); + roots.reserve(num_trees + 1); + roots.push_back(0); for (size_t i = 0; i < num_trees; ++i) { int num_nodes = node_traits::IS_DENSE ? tree_num_nodes(params.depth) : model.trees[i].num_nodes; - trees.push_back(num_nodes + trees.back()); + roots.push_back(num_nodes + roots.back()); } - size_t total_nodes = trees.back(); - trees.pop_back(); + size_t total_nodes = roots.back(); + roots.pop_back(); if (params.leaf_algo == VECTOR_LEAF) { size_t max_leaves = node_traits::IS_DENSE @@ -611,23 +565,63 @@ struct tl2fil_t { // convert the nodes #pragma omp parallel for - for (std::size_t i = 0; i < num_trees; ++i) { + for (std::size_t tree_idx = 0; tree_idx < num_trees; ++tree_idx) { // Max number of leaves processed so far - size_t leaf_counter = (trees[i] + i) / 2; - tree2fil(nodes, trees[i], model.trees[i], i, params, &vector_leaf, &leaf_counter, &cat_sets); + size_t leaf_counter = (roots[tree_idx] + tree_idx) / 2; + tree2fil(roots[tree_idx], model.trees[tree_idx], tree_idx, &leaf_counter); } params.num_nodes = nodes.size(); } + int tree2fil(int root, + const tl::Tree& tree, + std::size_t tree_idx, + std::size_t* leaf_counter) + { + typedef std::pair pair_t; + std::stack stack; + int sparse_index = root + 1; + stack.push(pair_t(tree_root(tree), 0)); + while (!stack.empty()) { + const pair_t& top = stack.top(); + int node_id = top.first; + int cur = top.second; + stack.pop(); + + while (!tree.IsLeaf(node_id)) { + // reserve space for child nodes + // left is the offset of the left child node relative to the tree root + // in the array of all nodes of the FIL sparse forest + int left = node_traits::IS_DENSE ? 2 * cur + 1 : sparse_index - root; + sparse_index += 2; + conversion_state cs = tl2fil_inner_node( + left, tree, node_id, &cat_sets, &cat_sets.bit_pool_offsets[tree_idx]); + nodes[root + cur] = cs.node; + // push child nodes into the stack + stack.push(pair_t(cs.tl_right, left + 1)); + // stack.push(pair_t(tl_left, left)); + node_id = cs.tl_left; + cur = left; + } + + // leaf node + nodes[root + cur] = fil_node_t({}, {}, 0, false, true, false, 0); + tl2fil_leaf_payload( + &nodes[root + cur], root + cur, tree, node_id, params, &vector_leaf, leaf_counter); + } + + return root; + } + void init_gpu(const raft::handle_t& handle, forest_t* pforest, storage_type_t storage_type) { - init(handle, pforest, cat_sets.accessor(), vector_leaf, trees.data(), nodes.data(), ¶ms); + init(handle, pforest, cat_sets.accessor(), vector_leaf, roots.data(), nodes.data(), ¶ms); // sync is necessary as nodes are used in init(), // but destructed at the end of this function CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); if (tl_params.pforest_shape_str) { - *tl_params.pforest_shape_str = sprintf_shape(model, storage_type, nodes, trees, cat_sets); + *tl_params.pforest_shape_str = sprintf_shape(model, storage_type, nodes, roots, cat_sets); } } }; From 58044af34f8d8a3d693356d9f816629beaa9fbd5 Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Tue, 16 Nov 2021 21:45:50 -0800 Subject: [PATCH 10/24] tl2fil_t:: init_object(), init_forest() --- cpp/src/fil/treelite_import.cu | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/cpp/src/fil/treelite_import.cu b/cpp/src/fil/treelite_import.cu index d1a960c245..c909e1ae19 100644 --- a/cpp/src/fil/treelite_import.cu +++ b/cpp/src/fil/treelite_import.cu @@ -537,6 +537,10 @@ struct tl2fil_t { tl2fil_t(const tl::ModelImpl& model_, const treelite_params_t& tl_params_) : model(model_), tl_params(tl_params_) + { + } + + void init_object() { tl2fil_common(¶ms, model, &tl_params); node_traits::check(model); @@ -614,7 +618,7 @@ struct tl2fil_t { return root; } - void init_gpu(const raft::handle_t& handle, forest_t* pforest, storage_type_t storage_type) + void init_forest(const raft::handle_t& handle, forest_t* pforest, storage_type_t storage_type) { init(handle, pforest, cat_sets.accessor(), vector_leaf, roots.data(), nodes.data(), ¶ms); // sync is necessary as nodes are used in init(), @@ -666,17 +670,20 @@ void from_treelite(const raft::handle_t& handle, switch (storage_type) { case storage_type_t::DENSE: { tl2fil_t tl2fil(model, *tl_params); - tl2fil.init_gpu(handle, pforest, storage_type); + tl2fil.init_object(); + tl2fil.init_forest(handle, pforest, storage_type); break; } case storage_type_t::SPARSE: { tl2fil_t tl2fil(model, *tl_params); - tl2fil.init_gpu(handle, pforest, storage_type); + tl2fil.init_object(); + tl2fil.init_forest(handle, pforest, storage_type); break; } case storage_type_t::SPARSE8: { tl2fil_t tl2fil(model, *tl_params); - tl2fil.init_gpu(handle, pforest, storage_type); + tl2fil.init_object(); + tl2fil.init_forest(handle, pforest, storage_type); break; } default: ASSERT(false, "tl_params->sparse must be one of AUTO, DENSE or SPARSE"); From 8377aeee29eeb356483441c995e38a083f93fbc3 Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Tue, 16 Nov 2021 23:29:56 -0800 Subject: [PATCH 11/24] tracking tokens --- cpp/src/fil/treelite_import.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/fil/treelite_import.cu b/cpp/src/fil/treelite_import.cu index c909e1ae19..7dbaf28268 100644 --- a/cpp/src/fil/treelite_import.cu +++ b/cpp/src/fil/treelite_import.cu @@ -17,7 +17,7 @@ /** @file treelite_import.cu converts from treelite format to a FIL-centric CPU-RAM format, so that * fil.cu can make a `forest` object out of it. */ -#include "common.cuh" // for num_trees, tree_num_nodes +#include "common.cuh" // for node_traits, num_trees, tree_num_nodes #include "internal.cuh" // for MAX_FIL_INT_FLOAT, BITS_PER_BYTE, cat_feature_counters, cat_sets, cat_sets_owner, categorical_sets, leaf_algo_t #include // for algo_t, from_treelite, storage_type_repr, storage_type_t, treelite_params_t @@ -30,7 +30,7 @@ #include // for Operator, SplitFeatureType, kGE, kGT, kLE, kLT, kNumerical #include // for ModelHandle -#include // for Tree +#include // for Tree, Model, ModelImpl, ModelParam #include // for omp From e1702905620b8f0acb320d98646b9ab3c10557a0 Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Wed, 17 Nov 2021 16:29:54 -0800 Subject: [PATCH 12/24] addressed review comments --- cpp/src/fil/common.cuh | 1 + cpp/src/fil/treelite_import.cu | 192 ++++++++++++++++++--------------- 2 files changed, 106 insertions(+), 87 deletions(-) diff --git a/cpp/src/fil/common.cuh b/cpp/src/fil/common.cuh index 5b187567fc..a7485e38ed 100644 --- a/cpp/src/fil/common.cuh +++ b/cpp/src/fil/common.cuh @@ -28,6 +28,7 @@ #include "internal.cuh" +// needed for node_traits<...> namespace treelite { template struct ModelImpl; diff --git a/cpp/src/fil/treelite_import.cu b/cpp/src/fil/treelite_import.cu index 7dbaf28268..3af8e81e8b 100644 --- a/cpp/src/fil/treelite_import.cu +++ b/cpp/src/fil/treelite_import.cu @@ -339,6 +339,51 @@ conversion_state tl2fil_inner_node(int fil_left_child, return conversion_state{node, tl_left, tl_right}; } +template +int tree2fil(std::vector& nodes, + int root, + const tl::Tree& tree, + std::size_t tree_idx, + const forest_params_t& forest_params, + std::vector* vector_leaf, + std::size_t* leaf_counter, + cat_sets_owner* cat_sets) +{ + typedef std::pair pair_t; + std::stack stack; + int sparse_index = root + 1; + stack.push(pair_t(tree_root(tree), 0)); + while (!stack.empty()) { + const pair_t& top = stack.top(); + int node_id = top.first; + int cur = top.second; + stack.pop(); + + while (!tree.IsLeaf(node_id)) { + // reserve space for child nodes + // left is the offset of the left child node relative to the tree root + // in the array of all nodes of the FIL sparse forest + int left = node_traits::IS_DENSE ? 2 * cur + 1 : sparse_index - root; + sparse_index += 2; + conversion_state cs = tl2fil_inner_node( + left, tree, node_id, cat_sets, &cat_sets->bit_pool_offsets[tree_idx]); + nodes[root + cur] = cs.node; + // push child nodes into the stack + stack.push(pair_t(cs.tl_right, left + 1)); + // stack.push(pair_t(tl_left, left)); + node_id = cs.tl_left; + cur = left; + } + + // leaf node + nodes[root + cur] = fil_node_t({}, {}, 0, false, true, false, 0); + tl2fil_leaf_payload( + &nodes[root + cur], root + cur, tree, node_id, forest_params, vector_leaf, leaf_counter); + } + + return root; +} + struct level_entry { int n_branch_nodes, n_leaves; }; @@ -527,109 +572,91 @@ void node_traits::check(const treelite::ModelImpl& template struct tl2fil_t { - std::vector roots; - std::vector nodes; - std::vector vector_leaf; - forest_params_t params; - cat_sets_owner cat_sets; - const tl::ModelImpl& model; - const treelite_params_t& tl_params; + std::vector roots_; + std::vector nodes_; + std::vector vector_leaf_; + forest_params_t params_; + cat_sets_owner cat_sets_; + const tl::ModelImpl& model_; + const treelite_params_t& tl_params_; tl2fil_t(const tl::ModelImpl& model_, const treelite_params_t& tl_params_) - : model(model_), tl_params(tl_params_) + : model_(model_), tl_params_(tl_params_) { } - void init_object() + void init() { - tl2fil_common(¶ms, model, &tl_params); - node_traits::check(model); + tl2fil_common(¶ms_, model_, &tl_params_); + node_traits::check(model_); - size_t num_trees = model.trees.size(); + size_t num_trees = model_.trees.size(); - roots.reserve(num_trees + 1); - roots.push_back(0); + roots_.reserve(num_trees + 1); + roots_.push_back(0); for (size_t i = 0; i < num_trees; ++i) { - int num_nodes = - node_traits::IS_DENSE ? tree_num_nodes(params.depth) : model.trees[i].num_nodes; - roots.push_back(num_nodes + roots.back()); + int num_nodes = node_traits::IS_DENSE ? tree_num_nodes(params_.depth) + : model_.trees[i].num_nodes; + roots_.push_back(num_nodes + roots_.back()); } - size_t total_nodes = roots.back(); - roots.pop_back(); + size_t total_nodes = roots_.back(); + roots_.pop_back(); - if (params.leaf_algo == VECTOR_LEAF) { + if (params_.leaf_algo == VECTOR_LEAF) { size_t max_leaves = node_traits::IS_DENSE - ? num_trees * (tree_num_nodes(params.depth) + 1) / 2 + ? num_trees * (tree_num_nodes(params_.depth) + 1) / 2 : (total_nodes + num_trees) / 2; - vector_leaf.resize(max_leaves * params.num_classes); + vector_leaf_.resize(max_leaves * params_.num_classes); } - cat_sets = allocate_cat_sets_owner(model); - nodes.resize(total_nodes); + cat_sets_ = allocate_cat_sets_owner(model_); + nodes_.resize(total_nodes); -// convert the nodes +// convert the nodes_ #pragma omp parallel for for (std::size_t tree_idx = 0; tree_idx < num_trees; ++tree_idx) { // Max number of leaves processed so far - size_t leaf_counter = (roots[tree_idx] + tree_idx) / 2; - tree2fil(roots[tree_idx], model.trees[tree_idx], tree_idx, &leaf_counter); - } - - params.num_nodes = nodes.size(); - } - - int tree2fil(int root, - const tl::Tree& tree, - std::size_t tree_idx, - std::size_t* leaf_counter) - { - typedef std::pair pair_t; - std::stack stack; - int sparse_index = root + 1; - stack.push(pair_t(tree_root(tree), 0)); - while (!stack.empty()) { - const pair_t& top = stack.top(); - int node_id = top.first; - int cur = top.second; - stack.pop(); - - while (!tree.IsLeaf(node_id)) { - // reserve space for child nodes - // left is the offset of the left child node relative to the tree root - // in the array of all nodes of the FIL sparse forest - int left = node_traits::IS_DENSE ? 2 * cur + 1 : sparse_index - root; - sparse_index += 2; - conversion_state cs = tl2fil_inner_node( - left, tree, node_id, &cat_sets, &cat_sets.bit_pool_offsets[tree_idx]); - nodes[root + cur] = cs.node; - // push child nodes into the stack - stack.push(pair_t(cs.tl_right, left + 1)); - // stack.push(pair_t(tl_left, left)); - node_id = cs.tl_left; - cur = left; - } - - // leaf node - nodes[root + cur] = fil_node_t({}, {}, 0, false, true, false, 0); - tl2fil_leaf_payload( - &nodes[root + cur], root + cur, tree, node_id, params, &vector_leaf, leaf_counter); + size_t leaf_counter = (roots_[tree_idx] + tree_idx) / 2; + tree2fil(nodes_, + roots_[tree_idx], + model_.trees[tree_idx], + tree_idx, + params_, + &vector_leaf_, + &leaf_counter, + &cat_sets_); } - return root; + params_.num_nodes = nodes_.size(); } + /// initializes FIL forest object, to be ready to infer void init_forest(const raft::handle_t& handle, forest_t* pforest, storage_type_t storage_type) { - init(handle, pforest, cat_sets.accessor(), vector_leaf, roots.data(), nodes.data(), ¶ms); - // sync is necessary as nodes are used in init(), + ML::fil::init( + handle, pforest, cat_sets_.accessor(), vector_leaf_, roots_.data(), nodes_.data(), ¶ms_); + // sync is necessary as nodes_ are used in init(), // but destructed at the end of this function CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); - if (tl_params.pforest_shape_str) { - *tl_params.pforest_shape_str = sprintf_shape(model, storage_type, nodes, roots, cat_sets); + if (tl_params_.pforest_shape_str) { + *tl_params_.pforest_shape_str = + sprintf_shape(model_, storage_type, nodes_, roots_, cat_sets_); } } }; +template +void convert(const tl::ModelImpl& model, + const treelite_params_t& tl_params, + const raft::handle_t& handle, + forest_t* pforest, + storage_type_t storage_type) +{ + tl2fil_t tl2fil(model, tl_params); + tl2fil.init(); + tl2fil.init_forest(handle, pforest, storage_type); +} + template void from_treelite(const raft::handle_t& handle, forest_t* pforest, @@ -668,24 +695,15 @@ void from_treelite(const raft::handle_t& handle, } switch (storage_type) { - case storage_type_t::DENSE: { - tl2fil_t tl2fil(model, *tl_params); - tl2fil.init_object(); - tl2fil.init_forest(handle, pforest, storage_type); + case storage_type_t::DENSE: + convert(model, *tl_params, handle, pforest, storage_type); break; - } - case storage_type_t::SPARSE: { - tl2fil_t tl2fil(model, *tl_params); - tl2fil.init_object(); - tl2fil.init_forest(handle, pforest, storage_type); + case storage_type_t::SPARSE: + convert(model, *tl_params, handle, pforest, storage_type); break; - } - case storage_type_t::SPARSE8: { - tl2fil_t tl2fil(model, *tl_params); - tl2fil.init_object(); - tl2fil.init_forest(handle, pforest, storage_type); + case storage_type_t::SPARSE8: + convert(model, *tl_params, handle, pforest, storage_type); break; - } default: ASSERT(false, "tl_params->sparse must be one of AUTO, DENSE or SPARSE"); } } From 06bba79422fcbc0fb20725074e2ff30b14eb92c5 Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Thu, 18 Nov 2021 20:25:41 -0800 Subject: [PATCH 13/24] addressed review comments --- cpp/src/fil/common.cuh | 14 ++++++++------ cpp/src/fil/treelite_import.cu | 32 ++++++++++++-------------------- 2 files changed, 20 insertions(+), 26 deletions(-) diff --git a/cpp/src/fil/common.cuh b/cpp/src/fil/common.cuh index a7485e38ed..9c601c0716 100644 --- a/cpp/src/fil/common.cuh +++ b/cpp/src/fil/common.cuh @@ -126,18 +126,20 @@ struct sparse_forest; template struct node_traits { - using storage = sparse_storage; - using forest = sparse_forest; - static const bool IS_DENSE = false; + using storage = sparse_storage; + using forest = sparse_forest; + static const bool IS_DENSE = false; + static const bool storage_type_enum = std::is_same() ? SPARSE : SPARSE8; template static void check(const treelite::ModelImpl& model); }; template <> struct node_traits { - using storage = dense_storage; - using forest = dense_forest; - static const bool IS_DENSE = true; + using storage = dense_storage; + using forest = dense_forest; + static const bool IS_DENSE = true; + static const bool storage_type_enum = DENSE; template static void check(const treelite::ModelImpl& model) { diff --git a/cpp/src/fil/treelite_import.cu b/cpp/src/fil/treelite_import.cu index 3af8e81e8b..fe77b96fb9 100644 --- a/cpp/src/fil/treelite_import.cu +++ b/cpp/src/fil/treelite_import.cu @@ -351,6 +351,7 @@ int tree2fil(std::vector& nodes, { typedef std::pair pair_t; std::stack stack; + // needed if the node is sparse, to place within memory for the FIL tree int sparse_index = root + 1; stack.push(pair_t(tree_root(tree), 0)); while (!stack.empty()) { @@ -631,7 +632,7 @@ struct tl2fil_t { } /// initializes FIL forest object, to be ready to infer - void init_forest(const raft::handle_t& handle, forest_t* pforest, storage_type_t storage_type) + void init_forest(const raft::handle_t& handle, forest_t* pforest) { ML::fil::init( handle, pforest, cat_sets_.accessor(), vector_leaf_, roots_.data(), nodes_.data(), ¶ms_); @@ -639,22 +640,20 @@ struct tl2fil_t { // but destructed at the end of this function CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); if (tl_params_.pforest_shape_str) { - *tl_params_.pforest_shape_str = - sprintf_shape(model_, storage_type, nodes_, roots_, cat_sets_); + *tl_params_.pforest_shape_str = sprintf_shape(model_, nodes_, roots_, cat_sets_); } } }; template -void convert(const tl::ModelImpl& model, - const treelite_params_t& tl_params, - const raft::handle_t& handle, +void convert(const raft::handle_t& handle, forest_t* pforest, - storage_type_t storage_type) + const tl::ModelImpl& model, + const treelite_params_t& tl_params) { tl2fil_t tl2fil(model, tl_params); tl2fil.init(); - tl2fil.init_forest(handle, pforest, storage_type); + tl2fil.init_forest(handle, pforest); } template @@ -695,15 +694,9 @@ void from_treelite(const raft::handle_t& handle, } switch (storage_type) { - case storage_type_t::DENSE: - convert(model, *tl_params, handle, pforest, storage_type); - break; - case storage_type_t::SPARSE: - convert(model, *tl_params, handle, pforest, storage_type); - break; - case storage_type_t::SPARSE8: - convert(model, *tl_params, handle, pforest, storage_type); - break; + case storage_type_t::DENSE: convert(handle, pforest, model, *tl_params); break; + case storage_type_t::SPARSE: convert(handle, pforest, model, *tl_params); break; + case storage_type_t::SPARSE8: convert(handle, pforest, model, *tl_params); break; default: ASSERT(false, "tl_params->sparse must be one of AUTO, DENSE or SPARSE"); } } @@ -723,7 +716,6 @@ void from_treelite(const raft::handle_t& handle, // allocates caller-owned char* using malloc() template char* sprintf_shape(const tl::ModelImpl& model, - storage_type_t storage, const std::vector& nodes, const std::vector& trees, const cat_sets_owner cat_sets) @@ -732,8 +724,8 @@ char* sprintf_shape(const tl::ModelImpl& model, double size_mb = (trees.size() * sizeof(trees.front()) + nodes.size() * sizeof(nodes.front()) + cat_sets.bits.size()) / 1e6; - forest_shape << storage_type_repr[storage] << " model size " << std::setprecision(2) << size_mb - << " MB" << std::endl; + forest_shape << storage_type_repr[node_traits::storage_type_enum] << " model size " + << std::setprecision(2) << size_mb << " MB" << std::endl; if (cat_sets.bits.size() > 0) { forest_shape << "number of categorical nodes for each feature id: {"; std::size_t total_cat_nodes = 0; From d52b89614cee102d7edd0ae6a5187e4c09eb8021 Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Thu, 18 Nov 2021 20:36:08 -0800 Subject: [PATCH 14/24] fixed enum->bool bug --- cpp/src/fil/common.cuh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp/src/fil/common.cuh b/cpp/src/fil/common.cuh index 9c601c0716..466369cbf9 100644 --- a/cpp/src/fil/common.cuh +++ b/cpp/src/fil/common.cuh @@ -129,7 +129,8 @@ struct node_traits { using storage = sparse_storage; using forest = sparse_forest; static const bool IS_DENSE = false; - static const bool storage_type_enum = std::is_same() ? SPARSE : SPARSE8; + static const storage_type storage_type_enum = + std::is_same() ? SPARSE : SPARSE8; template static void check(const treelite::ModelImpl& model); }; @@ -139,7 +140,7 @@ struct node_traits { using storage = dense_storage; using forest = dense_forest; static const bool IS_DENSE = true; - static const bool storage_type_enum = DENSE; + static const storage_type storage_type_enum = DENSE; template static void check(const treelite::ModelImpl& model) { From 70cd427205284413b5a0ce89df3981bc025ac85c Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Tue, 23 Nov 2021 14:42:17 -0800 Subject: [PATCH 15/24] style --- cpp/src/fil/common.cuh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/src/fil/common.cuh b/cpp/src/fil/common.cuh index 466369cbf9..7a596fa08e 100644 --- a/cpp/src/fil/common.cuh +++ b/cpp/src/fil/common.cuh @@ -126,9 +126,9 @@ struct sparse_forest; template struct node_traits { - using storage = sparse_storage; - using forest = sparse_forest; - static const bool IS_DENSE = false; + using storage = sparse_storage; + using forest = sparse_forest; + static const bool IS_DENSE = false; static const storage_type storage_type_enum = std::is_same() ? SPARSE : SPARSE8; template @@ -137,9 +137,9 @@ struct node_traits { template <> struct node_traits { - using storage = dense_storage; - using forest = dense_forest; - static const bool IS_DENSE = true; + using storage = dense_storage; + using forest = dense_forest; + static const bool IS_DENSE = true; static const storage_type storage_type_enum = DENSE; template static void check(const treelite::ModelImpl& model) From 5f5ce4ec9582ac9dad22b8992e4274425c019e1f Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Tue, 23 Nov 2021 16:13:31 -0800 Subject: [PATCH 16/24] typo --- cpp/src/fil/common.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/fil/common.cuh b/cpp/src/fil/common.cuh index 7a596fa08e..5ac0b0efb7 100644 --- a/cpp/src/fil/common.cuh +++ b/cpp/src/fil/common.cuh @@ -129,7 +129,7 @@ struct node_traits { using storage = sparse_storage; using forest = sparse_forest; static const bool IS_DENSE = false; - static const storage_type storage_type_enum = + static const storage_type_t storage_type_enum = std::is_same() ? SPARSE : SPARSE8; template static void check(const treelite::ModelImpl& model); @@ -140,7 +140,7 @@ struct node_traits { using storage = dense_storage; using forest = dense_forest; static const bool IS_DENSE = true; - static const storage_type storage_type_enum = DENSE; + static const storage_type_t storage_type_enum = DENSE; template static void check(const treelite::ModelImpl& model) { From eea9babc149bc82ba2213061763b198247ff166f Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Tue, 23 Nov 2021 17:31:56 -0800 Subject: [PATCH 17/24] style --- cpp/src/fil/common.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/src/fil/common.cuh b/cpp/src/fil/common.cuh index 5ac0b0efb7..446fd78e7e 100644 --- a/cpp/src/fil/common.cuh +++ b/cpp/src/fil/common.cuh @@ -137,9 +137,9 @@ struct node_traits { template <> struct node_traits { - using storage = dense_storage; - using forest = dense_forest; - static const bool IS_DENSE = true; + using storage = dense_storage; + using forest = dense_forest; + static const bool IS_DENSE = true; static const storage_type_t storage_type_enum = DENSE; template static void check(const treelite::ModelImpl& model) From 6e12723cb02eb673b87309b36c74470fae256b85 Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Wed, 1 Dec 2021 16:19:29 -0800 Subject: [PATCH 18/24] unified dense adn sparse tests; test cases are almost entirely disjoint --- cpp/test/sg/fil_test.cu | 69 +++++++++++++++++------------------------ 1 file changed, 28 insertions(+), 41 deletions(-) diff --git a/cpp/test/sg/fil_test.cu b/cpp/test/sg/fil_test.cu index db7087c140..b86ee5bb1d 100644 --- a/cpp/test/sg/fil_test.cu +++ b/cpp/test/sg/fil_test.cu @@ -643,31 +643,8 @@ class BaseFilTest : public testing::TestWithParam { FilTestParams ps; }; -class PredictDenseFilTest : public BaseFilTest { - protected: - void init_forest(fil::forest_t* pforest) override - { - // init FIL model - fil::forest_params_t fil_ps; - fil_ps.depth = ps.depth; - fil_ps.num_trees = ps.num_trees; - fil_ps.num_cols = ps.num_cols; - fil_ps.algo = ps.algo; - fil_ps.output = ps.output; - fil_ps.threshold = ps.threshold; - fil_ps.global_bias = ps.global_bias; - fil_ps.leaf_algo = ps.leaf_algo; - fil_ps.num_classes = ps.num_classes; - fil_ps.blocks_per_sm = ps.blocks_per_sm; - fil_ps.threads_per_tree = ps.threads_per_tree; - fil_ps.n_items = ps.n_items; - - fil::init(handle, pforest, cat_sets_h.accessor(), vector_leaf, nullptr, nodes.data(), &fil_ps); - } -}; - template -class BasePredictSparseFilTest : public BaseFilTest { +class BasePredictFilTest : public BaseFilTest { protected: void dense2sparse_node(const fil::dense_node* dense_root, int i_dense, @@ -714,22 +691,31 @@ class BasePredictSparseFilTest : public BaseFilTest { void init_forest(fil::forest_t* pforest) override { + if constexpr (!node_traits::IS_DENSE) { + dense2sparse(); + } else { + sparse_nodes = nodes; + // fil_params.num_nodes = forest_num_nodes(); + } + ASSERT(sparse_nodes.size() < std::size_t(INT_MAX), "generated too many nodes"); + // init FIL model - fil::forest_params_t fil_params; - fil_params.num_trees = ps.num_trees; - fil_params.num_cols = ps.num_cols; - fil_params.algo = ps.algo; - fil_params.output = ps.output; - fil_params.threshold = ps.threshold; - fil_params.global_bias = ps.global_bias; - fil_params.leaf_algo = ps.leaf_algo; - fil_params.num_classes = ps.num_classes; - fil_params.blocks_per_sm = ps.blocks_per_sm; - fil_params.threads_per_tree = ps.threads_per_tree; - fil_params.n_items = ps.n_items; - - dense2sparse(); - fil_params.num_nodes = sparse_nodes.size(); + fil::forest_params_t fil_params = { + .num_nodes = static_cast(sparse_nodes.size()), + .depth = ps.depth, + .num_trees = ps.num_trees, + .num_cols = ps.num_cols, + .leaf_algo = ps.leaf_algo, + .algo = ps.algo, + .output = ps.output, + .threshold = ps.threshold, + .global_bias = ps.global_bias, + .num_classes = ps.num_classes, + .blocks_per_sm = ps.blocks_per_sm, + .threads_per_tree = ps.threads_per_tree, + .n_items = ps.n_items, + }; + fil::init(handle, pforest, cat_sets_h.accessor(), @@ -742,8 +728,9 @@ class BasePredictSparseFilTest : public BaseFilTest { std::vector trees; }; -typedef BasePredictSparseFilTest PredictSparse16FilTest; -typedef BasePredictSparseFilTest PredictSparse8FilTest; +typedef BasePredictFilTest PredictDenseFilTest; +typedef BasePredictFilTest PredictSparse16FilTest; +typedef BasePredictFilTest PredictSparse8FilTest; class TreeliteFilTest : public BaseFilTest { protected: From 32a51032df1a2ffd2e1ae25d75a13dfb4dc31f3e Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Wed, 1 Dec 2021 20:25:00 -0800 Subject: [PATCH 19/24] stray comment --- cpp/test/sg/fil_test.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/test/sg/fil_test.cu b/cpp/test/sg/fil_test.cu index b86ee5bb1d..24e253c6cb 100644 --- a/cpp/test/sg/fil_test.cu +++ b/cpp/test/sg/fil_test.cu @@ -695,7 +695,6 @@ class BasePredictFilTest : public BaseFilTest { dense2sparse(); } else { sparse_nodes = nodes; - // fil_params.num_nodes = forest_num_nodes(); } ASSERT(sparse_nodes.size() < std::size_t(INT_MAX), "generated too many nodes"); From fe38a39f3fe675d355e8dd1d6761e52c57e2ea2a Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Thu, 16 Dec 2021 16:08:42 -0800 Subject: [PATCH 20/24] addressed review comments --- cpp/src/fil/treelite_import.cu | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/cpp/src/fil/treelite_import.cu b/cpp/src/fil/treelite_import.cu index fe77b96fb9..c57127843f 100644 --- a/cpp/src/fil/treelite_import.cu +++ b/cpp/src/fil/treelite_import.cu @@ -588,25 +588,22 @@ struct tl2fil_t { void init() { + static const bool IS_DENSE = node_traits::IS_DENSE; tl2fil_common(¶ms_, model_, &tl_params_); node_traits::check(model_); - size_t num_trees = model_.trees.size(); + std::size_t num_trees = model_.trees.size(); - roots_.reserve(num_trees + 1); - roots_.push_back(0); - for (size_t i = 0; i < num_trees; ++i) { - int num_nodes = node_traits::IS_DENSE ? tree_num_nodes(params_.depth) - : model_.trees[i].num_nodes; - roots_.push_back(num_nodes + roots_.back()); + std::size_t total_nodes = 0; + roots_.reserve(num_trees); + for (auto& tree : model_.trees) { + roots_.push_back(total_nodes); + total_nodes += IS_DENSE ? tree_num_nodes(params_.depth) : tree.num_nodes; } - size_t total_nodes = roots_.back(); - roots_.pop_back(); if (params_.leaf_algo == VECTOR_LEAF) { - size_t max_leaves = node_traits::IS_DENSE - ? num_trees * (tree_num_nodes(params_.depth) + 1) / 2 - : (total_nodes + num_trees) / 2; + std::size_t max_leaves = IS_DENSE ? num_trees * (tree_num_nodes(params_.depth) + 1) / 2 + : (total_nodes + num_trees) / 2; vector_leaf_.resize(max_leaves * params_.num_classes); } From d9d888f0feb3292da36ad3bafaf94d6679a511a7 Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Thu, 16 Dec 2021 17:11:50 -0800 Subject: [PATCH 21/24] moved node_traits from common.cuh to internal.cuh to use in fil_tests.cu --- cpp/src/fil/common.cuh | 33 --------------------------------- cpp/src/fil/internal.cuh | 33 +++++++++++++++++++++++++++++++++ cpp/test/sg/fil_test.cu | 14 ++++++-------- 3 files changed, 39 insertions(+), 41 deletions(-) diff --git a/cpp/src/fil/common.cuh b/cpp/src/fil/common.cuh index 446fd78e7e..9df0b6a3df 100644 --- a/cpp/src/fil/common.cuh +++ b/cpp/src/fil/common.cuh @@ -28,12 +28,6 @@ #include "internal.cuh" -// needed for node_traits<...> -namespace treelite { -template -struct ModelImpl; -} - namespace ML { namespace fil { @@ -120,33 +114,6 @@ struct sparse_storage : storage_base { typedef sparse_storage sparse_storage16; typedef sparse_storage sparse_storage8; -struct dense_forest; -template -struct sparse_forest; - -template -struct node_traits { - using storage = sparse_storage; - using forest = sparse_forest; - static const bool IS_DENSE = false; - static const storage_type_t storage_type_enum = - std::is_same() ? SPARSE : SPARSE8; - template - static void check(const treelite::ModelImpl& model); -}; - -template <> -struct node_traits { - using storage = dense_storage; - using forest = dense_forest; - static const bool IS_DENSE = true; - static const storage_type_t storage_type_enum = DENSE; - template - static void check(const treelite::ModelImpl& model) - { - } -}; - /// all model parameters mostly required to compute shared memory footprint, /// also the footprint itself struct shmem_size_params { diff --git a/cpp/src/fil/internal.cuh b/cpp/src/fil/internal.cuh index 885a9772d2..92a67650d1 100644 --- a/cpp/src/fil/internal.cuh +++ b/cpp/src/fil/internal.cuh @@ -33,6 +33,12 @@ namespace raft { class handle_t; } +// needed for node_traits<...> +namespace treelite { +template +struct ModelImpl; +} + namespace ML { namespace fil { @@ -215,6 +221,33 @@ struct alignas(8) sparse_node8 : base_node { __host__ __device__ int left(int curr) const { return left_index(); } }; +struct dense_forest; +template +struct sparse_forest; + +template +struct node_traits { + using storage = sparse_storage; + using forest = sparse_forest; + static const bool IS_DENSE = false; + static const storage_type_t storage_type_enum = + std::is_same() ? SPARSE : SPARSE8; + template + static void check(const treelite::ModelImpl& model); +}; + +template <> +struct node_traits { + using storage = dense_storage; + using forest = dense_forest; + static const bool IS_DENSE = true; + static const storage_type_t storage_type_enum = DENSE; + template + static void check(const treelite::ModelImpl& model) + { + } +}; + /** leaf_algo_t describes what the leaves in a FIL forest store (predict) and how FIL aggregates them into class margins/regression result/best class **/ diff --git a/cpp/test/sg/fil_test.cu b/cpp/test/sg/fil_test.cu index 26e34c9ff4..43df4275f2 100644 --- a/cpp/test/sg/fil_test.cu +++ b/cpp/test/sg/fil_test.cu @@ -694,16 +694,14 @@ class BasePredictFilTest : public BaseFilTest { void init_forest(fil::forest_t* pforest) override { - if constexpr (!node_traits::IS_DENSE) { - dense2sparse(); - } else { - sparse_nodes = nodes; - } - ASSERT(sparse_nodes.size() < std::size_t(INT_MAX), "generated too many nodes"); + constexpr bool IS_DENSE = node_traits::IS_DENSE; + if constexpr (!IS_DENSE) dense2sparse(); + std::vector& init_nodes = IS_DENSE ? nodes : sparse_nodes; + ASSERT(init_nodes.size() < std::size_t(INT_MAX), "generated too many nodes"); // init FIL model fil::forest_params_t fil_params = { - .num_nodes = static_cast(sparse_nodes.size()), + .num_nodes = static_cast(init_nodes.size()), .depth = ps.depth, .num_trees = ps.num_trees, .num_cols = ps.num_cols, @@ -723,7 +721,7 @@ class BasePredictFilTest : public BaseFilTest { cat_sets_h.accessor(), vector_leaf, trees.data(), - sparse_nodes.data(), + init_nodes.data(), &fil_params); } std::vector sparse_nodes; From f36304d0b04ba11a6197b56fa40d557cfb7073c1 Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Fri, 17 Dec 2021 00:00:52 -0800 Subject: [PATCH 22/24] types --- cpp/src/fil/internal.cuh | 4 ++++ cpp/test/sg/fil_test.cu | 9 +++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/cpp/src/fil/internal.cuh b/cpp/src/fil/internal.cuh index 92a67650d1..9a0ff52a1e 100644 --- a/cpp/src/fil/internal.cuh +++ b/cpp/src/fil/internal.cuh @@ -225,6 +225,10 @@ struct dense_forest; template struct sparse_forest; +struct dense_storage; +template +struct sparse_storage; + template struct node_traits { using storage = sparse_storage; diff --git a/cpp/test/sg/fil_test.cu b/cpp/test/sg/fil_test.cu index 43df4275f2..898d810295 100644 --- a/cpp/test/sg/fil_test.cu +++ b/cpp/test/sg/fil_test.cu @@ -695,8 +695,13 @@ class BasePredictFilTest : public BaseFilTest { void init_forest(fil::forest_t* pforest) override { constexpr bool IS_DENSE = node_traits::IS_DENSE; - if constexpr (!IS_DENSE) dense2sparse(); - std::vector& init_nodes = IS_DENSE ? nodes : sparse_nodes; + std::vector& init_nodes; + if constexpr (!IS_DENSE) { + dense2sparse(); + init_nodes = sparse_nodes; + } else { + init_nodes = nodes; + } ASSERT(init_nodes.size() < std::size_t(INT_MAX), "generated too many nodes"); // init FIL model From 8de22be3d2ac192aad6a8e271777cb938acecdb0 Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Fri, 17 Dec 2021 10:50:24 -0800 Subject: [PATCH 23/24] ref -> val --- cpp/test/sg/fil_test.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/test/sg/fil_test.cu b/cpp/test/sg/fil_test.cu index 898d810295..5fc78d0cb2 100644 --- a/cpp/test/sg/fil_test.cu +++ b/cpp/test/sg/fil_test.cu @@ -695,7 +695,7 @@ class BasePredictFilTest : public BaseFilTest { void init_forest(fil::forest_t* pforest) override { constexpr bool IS_DENSE = node_traits::IS_DENSE; - std::vector& init_nodes; + std::vector init_nodes; if constexpr (!IS_DENSE) { dense2sparse(); init_nodes = sparse_nodes; From 5679fe1ca29b1dbb5eacedfe8810042828581648 Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Fri, 17 Dec 2021 14:21:23 -0800 Subject: [PATCH 24/24] fix conflict resolution --- cpp/src/fil/common.cuh | 33 --------------------------------- 1 file changed, 33 deletions(-) diff --git a/cpp/src/fil/common.cuh b/cpp/src/fil/common.cuh index 446fd78e7e..9df0b6a3df 100644 --- a/cpp/src/fil/common.cuh +++ b/cpp/src/fil/common.cuh @@ -28,12 +28,6 @@ #include "internal.cuh" -// needed for node_traits<...> -namespace treelite { -template -struct ModelImpl; -} - namespace ML { namespace fil { @@ -120,33 +114,6 @@ struct sparse_storage : storage_base { typedef sparse_storage sparse_storage16; typedef sparse_storage sparse_storage8; -struct dense_forest; -template -struct sparse_forest; - -template -struct node_traits { - using storage = sparse_storage; - using forest = sparse_forest; - static const bool IS_DENSE = false; - static const storage_type_t storage_type_enum = - std::is_same() ? SPARSE : SPARSE8; - template - static void check(const treelite::ModelImpl& model); -}; - -template <> -struct node_traits { - using storage = dense_storage; - using forest = dense_forest; - static const bool IS_DENSE = true; - static const storage_type_t storage_type_enum = DENSE; - template - static void check(const treelite::ModelImpl& model) - { - } -}; - /// all model parameters mostly required to compute shared memory footprint, /// also the footprint itself struct shmem_size_params {