diff --git a/include/treelite/tree.h b/include/treelite/tree.h index e7f39315..50b21307 100644 --- a/include/treelite/tree.h +++ b/include/treelite/tree.h @@ -278,7 +278,11 @@ class Tree { // vector of nodes ContiguousArray nodes_; ContiguousArray leaf_vector_; - ContiguousArray leaf_vector_offset_; + // Map nid to the start and end index in leaf_vector_ + // We could use std::pair, but it is not POD, so easier to use two vectors + // here + ContiguousArray leaf_vector_begin_; + ContiguousArray leaf_vector_end_; ContiguousArray matching_categories_; ContiguousArray matching_categories_offset_; @@ -369,8 +373,8 @@ class Tree { * \param nid ID of node being queried */ inline std::vector LeafVector(int nid) const { - const std::size_t offset_begin = leaf_vector_offset_.at(nid); - const std::size_t offset_end = leaf_vector_offset_.at(nid + 1); + const std::size_t offset_begin = leaf_vector_begin_.at(nid); + const std::size_t offset_end = leaf_vector_end_.at(nid); if (offset_begin >= leaf_vector_.Size() || offset_end > leaf_vector_.Size()) { // Return empty vector, to indicate the lack of leaf vector return std::vector(); @@ -385,7 +389,7 @@ class Tree { * \param nid ID of node being queried */ inline bool HasLeafVector(int nid) const { - return leaf_vector_offset_.at(nid) != leaf_vector_offset_.at(nid + 1); + return leaf_vector_begin_.at(nid) != leaf_vector_end_.at(nid); } /*! * \brief get threshold of the node diff --git a/include/treelite/tree_impl.h b/include/treelite/tree_impl.h index 05e38fa1..041f51d4 100644 --- a/include/treelite/tree_impl.h +++ b/include/treelite/tree_impl.h @@ -485,7 +485,8 @@ Tree::Clone() const { tree.num_nodes = num_nodes; tree.nodes_ = nodes_.Clone(); tree.leaf_vector_ = leaf_vector_.Clone(); - tree.leaf_vector_offset_ = leaf_vector_offset_.Clone(); + tree.leaf_vector_begin_ = leaf_vector_begin_.Clone(); + tree.leaf_vector_end_ = leaf_vector_end_.Clone(); tree.matching_categories_ = matching_categories_.Clone(); tree.matching_categories_offset_ = matching_categories_offset_.Clone(); return tree; @@ -501,7 +502,7 @@ Tree::GetFormatStringForNode() { } } -constexpr std::size_t kNumFramePerTree = 6; +constexpr std::size_t kNumFramePerTree = 7; template template @@ -512,7 +513,8 @@ Tree::SerializeTemplate( scalar_handler(&num_nodes); composite_array_handler(&nodes_, GetFormatStringForNode()); primitive_array_handler(&leaf_vector_); - primitive_array_handler(&leaf_vector_offset_); + primitive_array_handler(&leaf_vector_begin_); + primitive_array_handler(&leaf_vector_end_); primitive_array_handler(&matching_categories_); primitive_array_handler(&matching_categories_offset_); } @@ -528,7 +530,8 @@ Tree::DeserializeTemplate( throw std::runtime_error("Could not load the correct number of nodes"); } array_handler(&leaf_vector_); - array_handler(&leaf_vector_offset_); + array_handler(&leaf_vector_begin_); + array_handler(&leaf_vector_end_); array_handler(&matching_categories_); array_handler(&matching_categories_offset_); } @@ -614,7 +617,8 @@ Tree::AllocNode() { throw std::runtime_error("Invariant violated: nodes_ contains incorrect number of nodes"); } for (int nid = nd; nid < num_nodes; ++nid) { - leaf_vector_offset_.PushBack(leaf_vector_offset_.Back()); + leaf_vector_begin_.PushBack(0); + leaf_vector_end_.PushBack(0); matching_categories_offset_.PushBack(matching_categories_offset_.Back()); nodes_.Resize(nodes_.Size() + 1); nodes_.Back().Init(); @@ -627,7 +631,8 @@ inline void Tree::Init() { num_nodes = 1; leaf_vector_.Clear(); - leaf_vector_offset_.Resize(2, 0); + leaf_vector_begin_.Resize(1, {}); + leaf_vector_end_.Resize(1, {}); matching_categories_.Clear(); matching_categories_offset_.Resize(2, 0); nodes_.Resize(1); @@ -737,24 +742,12 @@ template inline void Tree::SetLeafVector( int nid, const std::vector& node_leaf_vector) { - const std::size_t end_oft = leaf_vector_offset_.Back(); - const std::size_t new_end_oft = end_oft + node_leaf_vector.size(); - if (end_oft != leaf_vector_.Size()) { - throw std::runtime_error("Invariant violated"); - } - if (!std::all_of(&leaf_vector_offset_.at(nid + 1), leaf_vector_offset_.End(), - [end_oft](std::size_t x) { return (x == end_oft); })) { - throw std::runtime_error("Invariant violated"); - } - // Hopefully we won't have to move any element as we add leaf vector elements for node nid + std::size_t begin = leaf_vector_.Size(); + std::size_t end = begin + node_leaf_vector.size(); leaf_vector_.Extend(node_leaf_vector); - if (new_end_oft != leaf_vector_.Size()) { - throw std::runtime_error("Invariant violated"); - } - std::for_each(&leaf_vector_offset_.at(nid + 1), leaf_vector_offset_.End(), - [new_end_oft](std::size_t& x) { x = new_end_oft; }); - - Node& node = nodes_.at(nid); + leaf_vector_begin_[nid] = begin; + leaf_vector_end_[nid] = end; + Node &node = nodes_.at(nid); node.cleft_ = -1; node.cright_ = -1; node.split_type_ = SplitFeatureType::kNone; diff --git a/src/json_serializer.cc b/src/json_serializer.cc index 92fcac21..55f19d29 100644 --- a/src/json_serializer.cc +++ b/src/json_serializer.cc @@ -116,8 +116,10 @@ void SerializeTreeToJSON(WriterType& writer, const Tree