Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve vector leaf performance #311

Merged
merged 5 commits into from
Sep 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions include/treelite/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,11 @@ class Tree {
// vector of nodes
ContiguousArray<Node> nodes_;
ContiguousArray<LeafOutputType> leaf_vector_;
ContiguousArray<std::size_t> leaf_vector_offset_;
// Map nid to the start and end index in leaf_vector_
// We could use std::pair, but it is not POD, so easier to use two vectors
// here
ContiguousArray<std::size_t> leaf_vector_begin_;
ContiguousArray<std::size_t> leaf_vector_end_;
ContiguousArray<uint32_t> matching_categories_;
ContiguousArray<std::size_t> matching_categories_offset_;

Expand Down Expand Up @@ -369,8 +373,8 @@ class Tree {
* \param nid ID of node being queried
*/
inline std::vector<LeafOutputType> LeafVector(int nid) const {
const std::size_t offset_begin = leaf_vector_offset_.at(nid);
const std::size_t offset_end = leaf_vector_offset_.at(nid + 1);
const std::size_t offset_begin = leaf_vector_begin_.at(nid);
const std::size_t offset_end = leaf_vector_end_.at(nid);
if (offset_begin >= leaf_vector_.Size() || offset_end > leaf_vector_.Size()) {
// Return empty vector, to indicate the lack of leaf vector
return std::vector<LeafOutputType>();
Expand All @@ -385,7 +389,7 @@ class Tree {
* \param nid ID of node being queried
*/
inline bool HasLeafVector(int nid) const {
return leaf_vector_offset_.at(nid) != leaf_vector_offset_.at(nid + 1);
return leaf_vector_begin_.at(nid) != leaf_vector_end_.at(nid);
}
/*!
* \brief get threshold of the node
Expand Down
39 changes: 16 additions & 23 deletions include/treelite/tree_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,8 @@ Tree<ThresholdType, LeafOutputType>::Clone() const {
tree.num_nodes = num_nodes;
tree.nodes_ = nodes_.Clone();
tree.leaf_vector_ = leaf_vector_.Clone();
tree.leaf_vector_offset_ = leaf_vector_offset_.Clone();
tree.leaf_vector_begin_ = leaf_vector_begin_.Clone();
tree.leaf_vector_end_ = leaf_vector_end_.Clone();
tree.matching_categories_ = matching_categories_.Clone();
tree.matching_categories_offset_ = matching_categories_offset_.Clone();
return tree;
Expand All @@ -501,7 +502,7 @@ Tree<ThresholdType, LeafOutputType>::GetFormatStringForNode() {
}
}

constexpr std::size_t kNumFramePerTree = 6;
constexpr std::size_t kNumFramePerTree = 7;

template <typename ThresholdType, typename LeafOutputType>
template <typename ScalarHandler, typename PrimitiveArrayHandler, typename CompositeArrayHandler>
Expand All @@ -512,7 +513,8 @@ Tree<ThresholdType, LeafOutputType>::SerializeTemplate(
scalar_handler(&num_nodes);
composite_array_handler(&nodes_, GetFormatStringForNode());
primitive_array_handler(&leaf_vector_);
primitive_array_handler(&leaf_vector_offset_);
primitive_array_handler(&leaf_vector_begin_);
primitive_array_handler(&leaf_vector_end_);
primitive_array_handler(&matching_categories_);
primitive_array_handler(&matching_categories_offset_);
}
Expand All @@ -528,7 +530,8 @@ Tree<ThresholdType, LeafOutputType>::DeserializeTemplate(
throw std::runtime_error("Could not load the correct number of nodes");
}
array_handler(&leaf_vector_);
array_handler(&leaf_vector_offset_);
array_handler(&leaf_vector_begin_);
array_handler(&leaf_vector_end_);
array_handler(&matching_categories_);
array_handler(&matching_categories_offset_);
}
Expand Down Expand Up @@ -614,7 +617,8 @@ Tree<ThresholdType, LeafOutputType>::AllocNode() {
throw std::runtime_error("Invariant violated: nodes_ contains incorrect number of nodes");
}
for (int nid = nd; nid < num_nodes; ++nid) {
leaf_vector_offset_.PushBack(leaf_vector_offset_.Back());
leaf_vector_begin_.PushBack(0);
leaf_vector_end_.PushBack(0);
matching_categories_offset_.PushBack(matching_categories_offset_.Back());
nodes_.Resize(nodes_.Size() + 1);
nodes_.Back().Init();
Expand All @@ -627,7 +631,8 @@ inline void
Tree<ThresholdType, LeafOutputType>::Init() {
num_nodes = 1;
leaf_vector_.Clear();
leaf_vector_offset_.Resize(2, 0);
leaf_vector_begin_.Resize(1, {});
leaf_vector_end_.Resize(1, {});
matching_categories_.Clear();
matching_categories_offset_.Resize(2, 0);
nodes_.Resize(1);
Expand Down Expand Up @@ -737,24 +742,12 @@ template <typename ThresholdType, typename LeafOutputType>
inline void
Tree<ThresholdType, LeafOutputType>::SetLeafVector(
int nid, const std::vector<LeafOutputType>& node_leaf_vector) {
const std::size_t end_oft = leaf_vector_offset_.Back();
const std::size_t new_end_oft = end_oft + node_leaf_vector.size();
if (end_oft != leaf_vector_.Size()) {
throw std::runtime_error("Invariant violated");
}
if (!std::all_of(&leaf_vector_offset_.at(nid + 1), leaf_vector_offset_.End(),
[end_oft](std::size_t x) { return (x == end_oft); })) {
throw std::runtime_error("Invariant violated");
}
// Hopefully we won't have to move any element as we add leaf vector elements for node nid
std::size_t begin = leaf_vector_.Size();
std::size_t end = begin + node_leaf_vector.size();
leaf_vector_.Extend(node_leaf_vector);
if (new_end_oft != leaf_vector_.Size()) {
throw std::runtime_error("Invariant violated");
}
std::for_each(&leaf_vector_offset_.at(nid + 1), leaf_vector_offset_.End(),
[new_end_oft](std::size_t& x) { x = new_end_oft; });

Node& node = nodes_.at(nid);
leaf_vector_begin_[nid] = begin;
leaf_vector_end_[nid] = end;
Node &node = nodes_.at(nid);
hcho3 marked this conversation as resolved.
Show resolved Hide resolved
node.cleft_ = -1;
node.cright_ = -1;
node.split_type_ = SplitFeatureType::kNone;
Expand Down
10 changes: 5 additions & 5 deletions src/json_serializer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,10 @@ void SerializeTreeToJSON(WriterType& writer, const Tree<ThresholdType, LeafOutpu
writer.Int(tree.num_nodes);
writer.Key("leaf_vector");
WriteContiguousArray(writer, tree.leaf_vector_);
writer.Key("leaf_vector_offset");
WriteContiguousArray(writer, tree.leaf_vector_offset_);
writer.Key("leaf_vector_begin");
WriteContiguousArray(writer, tree.leaf_vector_begin_);
writer.Key("leaf_vector_end");
WriteContiguousArray(writer, tree.leaf_vector_end_);
writer.Key("matching_categories");
WriteContiguousArray(writer, tree.matching_categories_);
writer.Key("matching_categories_offset");
Expand All @@ -131,10 +133,8 @@ void SerializeTreeToJSON(WriterType& writer, const Tree<ThresholdType, LeafOutpu

writer.EndObject();

// Sanity check
// Basic checks
TREELITE_CHECK_EQ(tree.nodes_.Size(), tree.num_nodes);
TREELITE_CHECK_EQ(tree.nodes_.Size() + 1, tree.leaf_vector_offset_.Size());
TREELITE_CHECK_EQ(tree.leaf_vector_offset_.Back(), tree.leaf_vector_.Size());
TREELITE_CHECK_EQ(tree.nodes_.Size() + 1, tree.matching_categories_offset_.Size());
TREELITE_CHECK_EQ(tree.matching_categories_offset_.Back(), tree.matching_categories_.Size());
}
Expand Down