Skip to content

Commit

Permalink
Universal Treelite tree walk function for FIL (#4407)
Browse files Browse the repository at this point in the history
During treelite import, FIL walks each tree 4 times for 4 different purposes. As a first step in shrinking the code, rewrite all the walks in terms of a universal Treelite tree walk function, with two lambdas: one for inner (branch) node and one for leaf node visitation.

Andy and Philip reviewed this PR presumably from this link:
https://github.com/levsnv/cuml/pull/2/files?diff=unified&w=1
The contents are the same; it merely parallelized the reviews by making a PR against a different PR (branch), which is now merged.

Authors:
  - Levs Dolgovs (https://github.com/levsnv)

Approvers:
  - Andy Adinets (https://github.com/canonizer)
  - Philip Hyunsu Cho (https://github.com/hcho3)
  - William Hicks (https://github.com/wphicks)

URL: #4407
  • Loading branch information
levsnv authored Dec 18, 2021
1 parent 81648cb commit 03132e8
Showing 1 changed file with 127 additions and 126 deletions.
253 changes: 127 additions & 126 deletions cpp/src/fil/treelite_import.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,32 +69,73 @@ int tree_root(const tl::Tree<T, L>& tree)
return 0; // Treelite format assumes that the root is 0
}

template <typename T, typename L>
inline int max_depth(const tl::Tree<T, L>& tree)
/// A no-op placeholder for values and callables alike: it carries no data, so it can
/// stand in for an unused state value, and calling it with any arguments does nothing,
/// so it can stand in for an unused visitor.
struct empty {
  /// Accept any number of arguments of any type and ignore them.
  /// Arguments are bound by forwarding reference, so nothing is copied and
  /// non-copyable objects can be passed. The call operator is const, so
  /// const-qualified placeholders are callable too, and it never throws.
  template <typename... Args>
  void operator()(Args&&...) const noexcept
  {
  }
};

/** walk a Treelite tree, visiting each inner node with visit_inner and each leaf node with
visit_leaf. See walk_tree::element::state documentation for how TraversalState is retained
during traversal. Any per-tree state during traversal should be captured by the lambdas themselves.
visit_inner(int node_id, TraversalState state) should return a pair of new states, one for
each child node. visit_leaf(int, TraversalState) returns nothing.
**/
template <typename T, typename L, typename InnerFunc, typename LeafFunc = empty>
inline void walk_tree(const tl::Tree<T, L>& tree,
InnerFunc visit_inner,
LeafFunc visit_leaf = empty())
{
// trees of this depth aren't used, so it most likely means bad input data,
// e.g. cycles in the forest
const int DEPTH_LIMIT = 500;
int root_index = tree_root(tree);
typedef std::pair<int, int> pair_t;
std::stack<pair_t> stack;
stack.push(pair_t(root_index, 0));
int max_depth = 0;
while (!stack.empty()) {
const pair_t& pair = stack.top();
int node_id = pair.first;
int depth = pair.second;
stack.pop();
while (!tree.IsLeaf(node_id)) {
stack.push(pair_t(tree.LeftChild(node_id), depth + 1));
node_id = tree.RightChild(node_id);
depth++;
ASSERT(depth < DEPTH_LIMIT, "depth limit reached, might be a cycle in the tree");
if constexpr (std::is_invocable<InnerFunc, int>()) {
/// wrapper for empty path state
walk_tree(
tree,
[&](int nid, empty val) {
visit_inner(nid);
return std::pair<empty, empty>();
},
[&](int nid, empty val) { visit_leaf(nid); });
} else {
using TraversalState = decltype(visit_inner(int(), {}).first);
/// needed to visit a node
struct element {
int tl_node_id;
/// Retained while visiting nodes on a single path from root to leaf.
/// This generalizes the node index that's carried over during inference tree traversal.
TraversalState state;
};
std::stack<element> stack;
stack.push(element{tree_root(tree), TraversalState()});
while (!stack.empty()) {
element i = stack.top();
stack.pop();
while (!tree.IsLeaf(i.tl_node_id)) {
auto [left_state, right_state] = visit_inner(i.tl_node_id, i.state);
stack.push(element{tree.LeftChild(i.tl_node_id), left_state});
i = element{tree.RightChild(i.tl_node_id), right_state};
}
visit_leaf(i.tl_node_id, i.state);
}
// only need to update depth for leaves
max_depth = std::max(max_depth, depth);
}
return max_depth;
}

/** Depth of the deepest leaf of a Treelite tree, computed with walk_tree.
    The traversal state carried along each root-to-leaf path is the depth of the
    node being visited; every inner node hands depth + 1 to both of its children.
    Asserts when an inner node is found at DEPTH_LIMIT or deeper.
**/
template <typename T, typename L>
inline int max_depth(const tl::Tree<T, L>& tree)
{
  int deepest_leaf = 0;

  // inner node: validate the depth, then pass depth + 1 down to both children
  const auto on_inner = [](int /*node_id*/, int depth) {
    // trees of this depth aren't used, so it most likely means bad input data,
    // e.g. cycles in the forest
    constexpr int DEPTH_LIMIT = 500;
    ASSERT(depth < DEPTH_LIMIT, "node_depth limit reached, might be a cycle in the tree");
    const int child_depth = depth + 1;
    return std::pair<int, int>{child_depth, child_depth};
  };

  // leaf node: only leaves can raise the running maximum
  const auto on_leaf = [&deepest_leaf](int /*node_id*/, int depth) {
    if (depth > deepest_leaf) { deepest_leaf = depth; }
  };

  walk_tree(tree, on_inner, on_leaf);
  return deepest_leaf;
}

template <typename T, typename L>
Expand Down Expand Up @@ -122,35 +163,26 @@ template <typename T, typename L>
inline std::vector<cat_feature_counters> cat_counter_vec(const tl::Tree<T, L>& tree, int n_cols)
{
std::vector<cat_feature_counters> res(n_cols);
std::stack<int> stack;
stack.push(tree_root(tree));
while (!stack.empty()) {
int node_id = stack.top();
stack.pop();
while (!tree.IsLeaf(node_id)) {
if (tree.SplitType(node_id) == tl::SplitFeatureType::kCategorical) {
std::vector<std::uint32_t> mmv = tree.MatchingCategories(node_id);
int max_matching_cat;
if (mmv.size() > 0) {
// in `struct cat_feature_counters` and GPU structures, max matching category is an int
// cast is safe because all precise int floats fit into ints, which are asserted to be 32
// bits
max_matching_cat = mmv.back();
ASSERT(max_matching_cat <= MAX_FIL_INT_FLOAT,
"FIL cannot infer on "
"more than %d matching categories",
MAX_FIL_INT_FLOAT);
} else {
max_matching_cat = -1;
}
cat_feature_counters& counters = res[tree.SplitIndex(node_id)];
counters =
cat_feature_counters::combine(counters, cat_feature_counters{max_matching_cat, 1});
walk_tree(tree, [&](int node_id) {
if (tree.SplitType(node_id) == tl::SplitFeatureType::kCategorical) {
std::vector<std::uint32_t> mmv = tree.MatchingCategories(node_id);
int max_matching_cat;
if (mmv.size() > 0) {
// in `struct cat_feature_counters` and GPU structures, int(max_matching_cat) is safe
// because all precise int floats fit into ints, which are asserted to be 32 bits
max_matching_cat = mmv.back();
ASSERT(max_matching_cat <= MAX_FIL_INT_FLOAT,
"FIL cannot infer on "
"more than %d matching categories",
MAX_FIL_INT_FLOAT);
} else {
max_matching_cat = -1;
}
stack.push(tree.LeftChild(node_id));
node_id = tree.RightChild(node_id);
cat_feature_counters& counters = res[tree.SplitIndex(node_id)];
counters = cat_feature_counters::combine(counters, cat_feature_counters{max_matching_cat, 1});
}
}
});

return res;
}

Expand All @@ -159,21 +191,12 @@ template <typename T, typename L>
inline std::size_t bit_pool_size(const tl::Tree<T, L>& tree, const categorical_sets& cat_sets)
{
std::size_t size = 0;
std::stack<int> stack;
stack.push(tree_root(tree));
while (!stack.empty()) {
int node_id = stack.top();
stack.pop();
while (!tree.IsLeaf(node_id)) {
if (tree.SplitType(node_id) == tl::SplitFeatureType::kCategorical &&
tree.MatchingCategories(node_id).size() > 0) {
int fid = tree.SplitIndex(node_id);
size += cat_sets.sizeof_mask(fid);
}
stack.push(tree.LeftChild(node_id));
node_id = tree.RightChild(node_id);
walk_tree(tree, [&](int node_id) {
if (tree.SplitType(node_id) == tl::SplitFeatureType::kCategorical &&
tree.MatchingCategories(node_id).size() > 0) {
size += cat_sets.sizeof_mask(tree.SplitIndex(node_id));
}
}
});
return size;
}

Expand All @@ -200,16 +223,14 @@ cat_sets_owner allocate_cat_sets_owner(const tl::ModelImpl<T, L>& model)
return cat_sets;
}

void adjust_threshold(
float* pthreshold, int* tl_left, int* tl_right, bool* default_left, tl::Operator comparison_op)
void adjust_threshold(float* pthreshold, bool* swap_child_nodes, tl::Operator comparison_op)
{
// in treelite (take left node if val [op] threshold),
// the meaning of the condition is reversed compared to FIL;
// thus, "<" in treelite corresponds to comparison ">=" used by FIL
// https://github.com/dmlc/treelite/blob/master/include/treelite/tree.h#L243
if (isnan(*pthreshold)) {
std::swap(*tl_left, *tl_right);
*default_left = !*default_left;
*swap_child_nodes = !*swap_child_nodes;
return;
}
switch (comparison_op) {
Expand All @@ -224,8 +245,7 @@ void adjust_threshold(
*pthreshold = std::nextafterf(*pthreshold, std::numeric_limits<float>::infinity());
case tl::Operator::kGE:
// swap left and right
std::swap(*tl_left, *tl_right);
*default_left = !*default_left;
*swap_child_nodes = !*swap_child_nodes;
break;
default: ASSERT(false, "only <, >, <= and >= comparisons are supported");
}
Expand Down Expand Up @@ -292,8 +312,7 @@ void tl2fil_leaf_payload(fil_node_t* fil_node,
template <typename fil_node_t>
struct conversion_state {
fil_node_t node;
int tl_left;
int tl_right;
bool swap_child_nodes;
};

// modifies cat_sets
Expand All @@ -309,16 +328,13 @@ conversion_state<fil_node_t> tl2fil_inner_node(int fil_left_child,
int feature_id = tree.SplitIndex(tl_node_id);
bool is_categorical = tree.SplitType(tl_node_id) == tl::SplitFeatureType::kCategorical &&
tree.MatchingCategories(tl_node_id).size() > 0;
bool default_left = tree.DefaultLeft(tl_node_id);
bool swap_child_nodes = false;
if (tree.SplitType(tl_node_id) == tl::SplitFeatureType::kNumerical) {
split.f = static_cast<float>(tree.Threshold(tl_node_id));
adjust_threshold(&split.f, &tl_left, &tl_right, &default_left, tree.ComparisonOp(tl_node_id));
adjust_threshold(&split.f, &swap_child_nodes, tree.ComparisonOp(tl_node_id));
} else if (tree.SplitType(tl_node_id) == tl::SplitFeatureType::kCategorical) {
// for FIL, the list of categories is always for the right child
if (!tree.CategoriesListRightChild(tl_node_id)) {
std::swap(tl_left, tl_right);
default_left = !default_left;
}
swap_child_nodes = !tree.CategoriesListRightChild(tl_node_id);
if (tree.MatchingCategories(tl_node_id).size() > 0) {
int sizeof_mask = cat_sets->accessor().sizeof_mask(feature_id);
split.idx = *bit_pool_offset;
Expand All @@ -335,8 +351,9 @@ conversion_state<fil_node_t> tl2fil_inner_node(int fil_left_child,
} else {
ASSERT(false, "only numerical and categorical split nodes are supported");
}
bool default_left = tree.DefaultLeft(tl_node_id) ^ swap_child_nodes;
fil_node_t node(val_t{}, split, feature_id, default_left, false, is_categorical, fil_left_child);
return conversion_state<fil_node_t>{node, tl_left, tl_right};
return conversion_state<fil_node_t>{node, swap_child_nodes};
}

template <typename fil_node_t, typename T, typename L>
Expand All @@ -349,69 +366,53 @@ int tree2fil(std::vector<fil_node_t>& nodes,
std::size_t* leaf_counter,
cat_sets_owner* cat_sets)
{
typedef std::pair<int, int> pair_t;
std::stack<pair_t> stack;
// needed if the node is sparse, to place within memory for the FIL tree
int sparse_index = root + 1;
stack.push(pair_t(tree_root(tree), 0));
while (!stack.empty()) {
const pair_t& top = stack.top();
int node_id = top.first;
int cur = top.second;
stack.pop();

while (!tree.IsLeaf(node_id)) {
int sparse_index = 1;
walk_tree(
tree,
[&](int node_id, int fil_node_id) {
// reserve space for child nodes
// left is the offset of the left child node relative to the tree root
// in the array of all nodes of the FIL sparse forest
int left = node_traits<fil_node_t>::IS_DENSE ? 2 * cur + 1 : sparse_index - root;
int left = node_traits<fil_node_t>::IS_DENSE ? 2 * fil_node_id + 1 : sparse_index;
sparse_index += 2;
conversion_state<fil_node_t> cs = tl2fil_inner_node<fil_node_t>(
left, tree, node_id, cat_sets, &cat_sets->bit_pool_offsets[tree_idx]);
nodes[root + cur] = cs.node;
// push child nodes into the stack
stack.push(pair_t(cs.tl_right, left + 1));
// stack.push(pair_t(tl_left, left));
node_id = cs.tl_left;
cur = left;
}

// leaf node
nodes[root + cur] = fil_node_t({}, {}, 0, false, true, false, 0);
tl2fil_leaf_payload(
&nodes[root + cur], root + cur, tree, node_id, forest_params, vector_leaf, leaf_counter);
}

nodes[root + fil_node_id] = cs.node;

return cs.swap_child_nodes ? std::pair(left + 1, left) : std::pair(left, left + 1);
},
[&](int node_id, int fil_node_id) {
nodes[root + fil_node_id] = fil_node_t({}, {}, 0, false, true, false, 0);
tl2fil_leaf_payload(&nodes[root + fil_node_id],
root + fil_node_id,
tree,
node_id,
forest_params,
vector_leaf,
leaf_counter);
});
return root;
}

/// Node counters for a single depth level of the depth histogram.
/// Kept as a plain aggregate so it can be brace-initialized, e.g. {0, 0}.
struct level_entry {
  /// number of inner (branch) nodes counted at this depth
  int n_branch_nodes;
  /// number of leaf nodes counted at this depth
  int n_leaves;
};
typedef std::pair<int, int> pair_t;
// hist has branch and leaf count given depth
template <typename T, typename L>
inline void tree_depth_hist(const tl::Tree<T, L>& tree, std::vector<level_entry>& hist)
inline void node_depth_hist(const tl::Tree<T, L>& tree, std::vector<level_entry>& hist)
{
std::stack<pair_t> stack; // {tl_id, depth}
stack.push({tree_root(tree), 0});
while (!stack.empty()) {
const pair_t& top = stack.top();
int node_id = top.first;
int depth = top.second;
stack.pop();

while (!tree.IsLeaf(node_id)) {
if (static_cast<std::size_t>(depth) >= hist.size()) hist.resize(depth + 1, {0, 0});
walk_tree(
tree,
[&](int node_id, std::size_t depth) {
if (depth >= hist.size()) hist.resize(depth + 1, {0, 0});
hist[depth].n_branch_nodes++;
stack.push({tree.LeftChild(node_id), depth + 1});
node_id = tree.RightChild(node_id);
depth++;
}

if (static_cast<std::size_t>(depth) >= hist.size()) hist.resize(depth + 1, {0, 0});
hist[depth].n_leaves++;
}
return std::pair(depth + 1, depth + 1);
},
[&](int node_id, std::size_t depth) {
if (depth >= hist.size()) hist.resize(depth + 1, {0, 0});
hist[depth].n_leaves++;
});
}

template <typename T, typename L>
Expand All @@ -420,7 +421,7 @@ std::stringstream depth_hist_and_max(const tl::ModelImpl<T, L>& model)
using namespace std;
vector<level_entry> hist;
for (const auto& tree : model.trees)
tree_depth_hist(tree, hist);
node_depth_hist(tree, hist);

int min_leaf_depth = -1, leaves_times_depth = 0, total_branches = 0, total_leaves = 0;
stringstream forest_shape;
Expand Down

0 comments on commit 03132e8

Please sign in to comment.