diff --git a/cpp/src/fil/treelite_import.cu b/cpp/src/fil/treelite_import.cu index c57127843f..3fcd7f0775 100644 --- a/cpp/src/fil/treelite_import.cu +++ b/cpp/src/fil/treelite_import.cu @@ -69,32 +69,73 @@ int tree_root(const tl::Tree& tree) return 0; // Treelite format assumes that the root is 0 } -template -inline int max_depth(const tl::Tree& tree) +// a no-op placeholder for values and callables alike +struct empty { + template + void operator()(Args...) + { + } +}; + +/** walk a Treelite tree, visiting each inner node with visit_inner and each leaf node with + visit_leaf. See walk_tree::element::state documentation for how TraversalState is retained +during traversal. Any per-tree state during traversal should be captured by the lambdas themselves. + visit_inner(int node_id, TraversalState state) should return a pair of new states, one for +each child node. visit_leaf(int, TraversalState) returns nothing. +**/ +template +inline void walk_tree(const tl::Tree& tree, + InnerFunc visit_inner, + LeafFunc visit_leaf = empty()) { - // trees of this depth aren't used, so it most likely means bad input data, - // e.g. cycles in the forest - const int DEPTH_LIMIT = 500; - int root_index = tree_root(tree); - typedef std::pair pair_t; - std::stack stack; - stack.push(pair_t(root_index, 0)); - int max_depth = 0; - while (!stack.empty()) { - const pair_t& pair = stack.top(); - int node_id = pair.first; - int depth = pair.second; - stack.pop(); - while (!tree.IsLeaf(node_id)) { - stack.push(pair_t(tree.LeftChild(node_id), depth + 1)); - node_id = tree.RightChild(node_id); - depth++; - ASSERT(depth < DEPTH_LIMIT, "depth limit reached, might be a cycle in the tree"); + if constexpr (std::is_invocable()) { + /// wrapper for empty path state + walk_tree( + tree, + [&](int nid, empty val) { + visit_inner(nid); + return std::pair(); + }, + [&](int nid, empty val) { visit_leaf(nid); }); + } else { + using TraversalState = decltype(visit_inner(int(), {}).first); + /// needed to visit a node + struct element { + int tl_node_id; + /// Retained while visiting nodes on a single path from root to leaf. + /// This generalizes the node index that's carried over during inference tree traversal. + TraversalState state; + }; + std::stack stack; + stack.push(element{tree_root(tree), TraversalState()}); + while (!stack.empty()) { + element i = stack.top(); + stack.pop(); + while (!tree.IsLeaf(i.tl_node_id)) { + auto [left_state, right_state] = visit_inner(i.tl_node_id, i.state); + stack.push(element{tree.LeftChild(i.tl_node_id), left_state}); + i = element{tree.RightChild(i.tl_node_id), right_state}; + } + visit_leaf(i.tl_node_id, i.state); } - // only need to update depth for leaves - max_depth = std::max(max_depth, depth); } - return max_depth; +} + +template +inline int max_depth(const tl::Tree& tree) +{ + int tree_depth = 0; + walk_tree( + tree, + [](int node_id, int node_depth) { + // trees of this depth aren't used, so it most likely means bad input data, + // e.g. cycles in the forest + constexpr int DEPTH_LIMIT = 500; + ASSERT(node_depth < DEPTH_LIMIT, "node_depth limit reached, might be a cycle in the tree"); + return std::pair(node_depth + 1, node_depth + 1); + }, + [&](int node_id, int node_depth) { tree_depth = std::max(node_depth, tree_depth); }); + return tree_depth; } template @@ -122,35 +163,26 @@ template inline std::vector cat_counter_vec(const tl::Tree& tree, int n_cols) { std::vector res(n_cols); - std::stack stack; - stack.push(tree_root(tree)); - while (!stack.empty()) { - int node_id = stack.top(); - stack.pop(); - while (!tree.IsLeaf(node_id)) { - if (tree.SplitType(node_id) == tl::SplitFeatureType::kCategorical) { - std::vector mmv = tree.MatchingCategories(node_id); - int max_matching_cat; - if (mmv.size() > 0) { - // in `struct cat_feature_counters` and GPU structures, max matching category is an int - // cast is safe because all precise int floats fit into ints, which are asserted to be 32 - // bits - max_matching_cat = mmv.back(); - ASSERT(max_matching_cat <= MAX_FIL_INT_FLOAT, - "FIL cannot infer on " - "more than %d matching categories", - MAX_FIL_INT_FLOAT); - } else { - max_matching_cat = -1; - } - cat_feature_counters& counters = res[tree.SplitIndex(node_id)]; - counters = - cat_feature_counters::combine(counters, cat_feature_counters{max_matching_cat, 1}); + walk_tree(tree, [&](int node_id) { + if (tree.SplitType(node_id) == tl::SplitFeatureType::kCategorical) { + std::vector mmv = tree.MatchingCategories(node_id); + int max_matching_cat; + if (mmv.size() > 0) { + // in `struct cat_feature_counters` and GPU structures, int(max_matching_cat) is safe + // because all precise int floats fit into ints, which are asserted to be 32 bits + max_matching_cat = mmv.back(); + ASSERT(max_matching_cat <= MAX_FIL_INT_FLOAT, + "FIL cannot infer on " + "more than %d matching categories", + MAX_FIL_INT_FLOAT); + } else { + max_matching_cat = -1; } - stack.push(tree.LeftChild(node_id)); - node_id = tree.RightChild(node_id); + cat_feature_counters& counters = res[tree.SplitIndex(node_id)]; + counters = cat_feature_counters::combine(counters, cat_feature_counters{max_matching_cat, 1}); } - } + }); + return res; } @@ -159,21 +191,12 @@ template inline std::size_t bit_pool_size(const tl::Tree& tree, const categorical_sets& cat_sets) { std::size_t size = 0; - std::stack stack; - stack.push(tree_root(tree)); - while (!stack.empty()) { - int node_id = stack.top(); - stack.pop(); - while (!tree.IsLeaf(node_id)) { - if (tree.SplitType(node_id) == tl::SplitFeatureType::kCategorical && - tree.MatchingCategories(node_id).size() > 0) { - int fid = tree.SplitIndex(node_id); - size += cat_sets.sizeof_mask(fid); - } - stack.push(tree.LeftChild(node_id)); - node_id = tree.RightChild(node_id); + walk_tree(tree, [&](int node_id) { + if (tree.SplitType(node_id) == tl::SplitFeatureType::kCategorical && + tree.MatchingCategories(node_id).size() > 0) { + size += cat_sets.sizeof_mask(tree.SplitIndex(node_id)); } - } + }); return size; } @@ -200,16 +223,14 @@ cat_sets_owner allocate_cat_sets_owner(const tl::ModelImpl& model) return cat_sets; } -void adjust_threshold( - float* pthreshold, int* tl_left, int* tl_right, bool* default_left, tl::Operator comparison_op) +void adjust_threshold(float* pthreshold, bool* swap_child_nodes, tl::Operator comparison_op) { // in treelite (take left node if val [op] threshold), // the meaning of the condition is reversed compared to FIL; // thus, "<" in treelite corresonds to comparison ">=" used by FIL // https://github.com/dmlc/treelite/blob/master/include/treelite/tree.h#L243 if (isnan(*pthreshold)) { - std::swap(*tl_left, *tl_right); - *default_left = !*default_left; + *swap_child_nodes = !*swap_child_nodes; return; } switch (comparison_op) { @@ -224,8 +245,7 @@ void adjust_threshold( *pthreshold = std::nextafterf(*pthreshold, std::numeric_limits::infinity()); case tl::Operator::kGE: // swap left and right - std::swap(*tl_left, *tl_right); - *default_left = !*default_left; + *swap_child_nodes = !*swap_child_nodes; break; default: ASSERT(false, "only <, >, <= and >= comparisons are supported"); } @@ -292,8 +312,7 @@ void tl2fil_leaf_payload(fil_node_t* fil_node, template struct conversion_state { fil_node_t node; - int tl_left; - int tl_right; + bool swap_child_nodes; }; // modifies cat_sets @@ -309,16 +328,13 @@ conversion_state tl2fil_inner_node(int fil_left_child, int feature_id = tree.SplitIndex(tl_node_id); bool is_categorical = tree.SplitType(tl_node_id) == tl::SplitFeatureType::kCategorical && tree.MatchingCategories(tl_node_id).size() > 0; - bool default_left = tree.DefaultLeft(tl_node_id); + bool swap_child_nodes = false; if (tree.SplitType(tl_node_id) == tl::SplitFeatureType::kNumerical) { split.f = static_cast(tree.Threshold(tl_node_id)); - adjust_threshold(&split.f, &tl_left, &tl_right, &default_left, tree.ComparisonOp(tl_node_id)); + adjust_threshold(&split.f, &swap_child_nodes, tree.ComparisonOp(tl_node_id)); } else if (tree.SplitType(tl_node_id) == tl::SplitFeatureType::kCategorical) { // for FIL, the list of categories is always for the right child - if (!tree.CategoriesListRightChild(tl_node_id)) { - std::swap(tl_left, tl_right); - default_left = !default_left; - } + swap_child_nodes = !tree.CategoriesListRightChild(tl_node_id); if (tree.MatchingCategories(tl_node_id).size() > 0) { int sizeof_mask = cat_sets->accessor().sizeof_mask(feature_id); split.idx = *bit_pool_offset; @@ -335,8 +351,9 @@ conversion_state tl2fil_inner_node(int fil_left_child, } else { ASSERT(false, "only numerical and categorical split nodes are supported"); } + bool default_left = tree.DefaultLeft(tl_node_id) ^ swap_child_nodes; fil_node_t node(val_t{}, split, feature_id, default_left, false, is_categorical, fil_left_child); - return conversion_state{node, tl_left, tl_right}; + return conversion_state{node, swap_child_nodes}; } template @@ -349,69 +366,53 @@ int tree2fil(std::vector& nodes, std::size_t* leaf_counter, cat_sets_owner* cat_sets) { - typedef std::pair pair_t; - std::stack stack; // needed if the node is sparse, to place within memory for the FIL tree - int sparse_index = root + 1; - stack.push(pair_t(tree_root(tree), 0)); - while (!stack.empty()) { - const pair_t& top = stack.top(); - int node_id = top.first; - int cur = top.second; - stack.pop(); - - while (!tree.IsLeaf(node_id)) { + int sparse_index = 1; + walk_tree( + tree, + [&](int node_id, int fil_node_id) { // reserve space for child nodes // left is the offset of the left child node relative to the tree root // in the array of all nodes of the FIL sparse forest - int left = node_traits::IS_DENSE ? 2 * cur + 1 : sparse_index - root; + int left = node_traits::IS_DENSE ? 2 * fil_node_id + 1 : sparse_index; sparse_index += 2; conversion_state cs = tl2fil_inner_node( left, tree, node_id, cat_sets, &cat_sets->bit_pool_offsets[tree_idx]); - nodes[root + cur] = cs.node; - // push child nodes into the stack - stack.push(pair_t(cs.tl_right, left + 1)); - // stack.push(pair_t(tl_left, left)); - node_id = cs.tl_left; - cur = left; - } - - // leaf node - nodes[root + cur] = fil_node_t({}, {}, 0, false, true, false, 0); - tl2fil_leaf_payload( - &nodes[root + cur], root + cur, tree, node_id, forest_params, vector_leaf, leaf_counter); - } - + nodes[root + fil_node_id] = cs.node; + + return cs.swap_child_nodes ? std::pair(left + 1, left) : std::pair(left, left + 1); + }, + [&](int node_id, int fil_node_id) { + nodes[root + fil_node_id] = fil_node_t({}, {}, 0, false, true, false, 0); + tl2fil_leaf_payload(&nodes[root + fil_node_id], + root + fil_node_id, + tree, + node_id, + forest_params, + vector_leaf, + leaf_counter); + }); return root; } struct level_entry { int n_branch_nodes, n_leaves; }; -typedef std::pair pair_t; // hist has branch and leaf count given depth template -inline void tree_depth_hist(const tl::Tree& tree, std::vector& hist) +inline void node_depth_hist(const tl::Tree& tree, std::vector& hist) { - std::stack stack; // {tl_id, depth} - stack.push({tree_root(tree), 0}); - while (!stack.empty()) { - const pair_t& top = stack.top(); - int node_id = top.first; - int depth = top.second; - stack.pop(); - - while (!tree.IsLeaf(node_id)) { - if (static_cast(depth) >= hist.size()) hist.resize(depth + 1, {0, 0}); + walk_tree( + tree, + [&](int node_id, std::size_t depth) { + if (depth >= hist.size()) hist.resize(depth + 1, {0, 0}); hist[depth].n_branch_nodes++; - stack.push({tree.LeftChild(node_id), depth + 1}); - node_id = tree.RightChild(node_id); - depth++; - } - - if (static_cast(depth) >= hist.size()) hist.resize(depth + 1, {0, 0}); - hist[depth].n_leaves++; - } + return std::pair(depth + 1, depth + 1); + }, + [&](int node_id, std::size_t depth) { + if (depth >= hist.size()) hist.resize(depth + 1, {0, 0}); + hist[depth].n_leaves++; + }); } template @@ -420,7 +421,7 @@ std::stringstream depth_hist_and_max(const tl::ModelImpl& model) using namespace std; vector hist; for (const auto& tree : model.trees) - tree_depth_hist(tree, hist); + node_depth_hist(tree, hist); int min_leaf_depth = -1, leaves_times_depth = 0, total_branches = 0, total_leaves = 0; stringstream forest_shape;