diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index ed3f10571edf..743bf0a66ce0 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -61,6 +61,7 @@ OBJECTS= \ $(PKGROOT)/src/tree/fit_stump.o \ $(PKGROOT)/src/tree/tree_model.o \ $(PKGROOT)/src/tree/tree_updater.o \ + $(PKGROOT)/src/tree/multi_target_tree_model.o \ $(PKGROOT)/src/tree/updater_approx.o \ $(PKGROOT)/src/tree/updater_colmaker.o \ $(PKGROOT)/src/tree/updater_prune.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index 024ba1aa19b7..a32d2fd2e45d 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -60,6 +60,7 @@ OBJECTS= \ $(PKGROOT)/src/tree/param.o \ $(PKGROOT)/src/tree/fit_stump.o \ $(PKGROOT)/src/tree/tree_model.o \ + $(PKGROOT)/src/tree/multi_target_tree_model.o \ $(PKGROOT)/src/tree/tree_updater.o \ $(PKGROOT)/src/tree/updater_approx.o \ $(PKGROOT)/src/tree/updater_colmaker.o \ diff --git a/include/xgboost/base.h b/include/xgboost/base.h index d12e71a3aa39..00fc7fb4ac63 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -110,11 +110,11 @@ using bst_bin_t = int32_t; // NOLINT */ using bst_row_t = std::size_t; // NOLINT /*! \brief Type for tree node index. */ -using bst_node_t = int32_t; // NOLINT +using bst_node_t = std::int32_t; // NOLINT /*! \brief Type for ranking group index. */ -using bst_group_t = uint32_t; // NOLINT -/*! \brief Type for indexing target variables. */ -using bst_target_t = std::size_t; // NOLINT +using bst_group_t = std::uint32_t; // NOLINT +/*! \brief Type for indexing into output targets. */ +using bst_target_t = std::uint32_t; // NOLINT namespace detail { /*! \brief Implementation of gradient statistics pair. Template specialisation diff --git a/include/xgboost/multi_target_tree_model.h b/include/xgboost/multi_target_tree_model.h new file mode 100644 index 000000000000..1ad7d6bf6a1c --- /dev/null +++ b/include/xgboost/multi_target_tree_model.h @@ -0,0 +1,96 @@ +/** + * Copyright 2023 by XGBoost contributors + * + * \brief Core data structure for multi-target trees. + */ +#ifndef XGBOOST_MULTI_TARGET_TREE_MODEL_H_ +#define XGBOOST_MULTI_TARGET_TREE_MODEL_H_ +#include // for bst_node_t, bst_target_t, bst_feature_t +#include // for Context +#include // for VectorView +#include // for Model +#include // for Span + +#include // for uint8_t +#include // for size_t +#include // for vector + +namespace xgboost { +struct TreeParam; +/** + * \brief Tree structure for multi-target model. + */ +class MultiTargetTree : public Model { + public: + static bst_node_t constexpr InvalidNodeId() { return -1; } + + private: + TreeParam const* param_; + std::vector left_; + std::vector right_; + std::vector parent_; + std::vector split_index_; + std::vector default_left_; + std::vector split_conds_; + std::vector weights_; + + [[nodiscard]] linalg::VectorView NodeWeight(bst_node_t nidx) const { + auto beg = nidx * this->NumTarget(); + auto v = common::Span{weights_}.subspan(beg, this->NumTarget()); + return linalg::MakeTensorView(Context::kCpuId, v, v.size()); + } + [[nodiscard]] linalg::VectorView NodeWeight(bst_node_t nidx) { + auto beg = nidx * this->NumTarget(); + auto v = common::Span{weights_}.subspan(beg, this->NumTarget()); + return linalg::MakeTensorView(Context::kCpuId, v, v.size()); + } + + public: + explicit MultiTargetTree(TreeParam const* param); + /** + * \brief Set the weight for a leaf. + */ + void SetLeaf(bst_node_t nidx, linalg::VectorView weight); + /** + * \brief Expand a leaf into split node. + */ + void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left, + linalg::VectorView base_weight, + linalg::VectorView left_weight, + linalg::VectorView right_weight); + + [[nodiscard]] bool IsLeaf(bst_node_t nidx) const { return left_[nidx] == InvalidNodeId(); } + [[nodiscard]] bst_node_t Parent(bst_node_t nidx) const { return parent_.at(nidx); } + [[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const { return left_.at(nidx); } + [[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const { return right_.at(nidx); } + + [[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const { return split_index_[nidx]; } + [[nodiscard]] float SplitCond(bst_node_t nidx) const { return split_conds_[nidx]; } + [[nodiscard]] bool DefaultLeft(bst_node_t nidx) const { return default_left_[nidx]; } + [[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const { + return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx); + } + + [[nodiscard]] bst_target_t NumTarget() const; + + [[nodiscard]] std::size_t Size() const; + + [[nodiscard]] bst_node_t Depth(bst_node_t nidx) const { + bst_node_t depth{0}; + while (Parent(nidx) != InvalidNodeId()) { + ++depth; + nidx = Parent(nidx); + } + return depth; + } + + [[nodiscard]] linalg::VectorView LeafValue(bst_node_t nidx) const { + CHECK(IsLeaf(nidx)); + return this->NodeWeight(nidx); + } + + void LoadModel(Json const& in) override; + void SaveModel(Json* out) const override; +}; +} // namespace xgboost +#endif // XGBOOST_MULTI_TARGET_TREE_MODEL_H_ diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 70c71cac1ad9..f646140dcb20 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -1,5 +1,5 @@ -/*! - * Copyright 2014-2022 by Contributors +/** + * Copyright 2014-2023 by Contributors * \file tree_model.h * \brief model structure for tree * \author Tianqi Chen @@ -9,60 +9,57 @@ #include #include - #include #include -#include #include +#include // for VectorView +#include #include +#include // for MultiTargetTree +#include +#include #include -#include +#include // for make_unique +#include #include -#include -#include #include -#include +#include namespace xgboost { - -struct PathElement; // forward declaration - class Json; + // FIXME(trivialfis): Once binary IO is gone, make this parameter internal as it should // not be configured by users. /*! \brief meta parameters of the tree */ struct TreeParam : public dmlc::Parameter { /*! \brief (Deprecated) number of start root */ - int deprecated_num_roots; + int deprecated_num_roots{1}; /*! \brief total number of nodes */ - int num_nodes; + int num_nodes{1}; /*!\brief number of deleted nodes */ - int num_deleted; + int num_deleted{0}; /*! \brief maximum depth, this is a statistics of the tree */ - int deprecated_max_depth; + int deprecated_max_depth{0}; /*! \brief number of features used for tree construction */ - bst_feature_t num_feature; + bst_feature_t num_feature{0}; /*! * \brief leaf vector size, used for vector tree * used to store more than one dimensional information in tree */ - int size_leaf_vector; + bst_target_t size_leaf_vector{1}; /*! \brief reserved part, make sure alignment works for 64bit */ int reserved[31]; /*! \brief constructor */ TreeParam() { // assert compact alignment - static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int), - "TreeParam: 64 bit align"); - std::memset(this, 0, sizeof(TreeParam)); - num_nodes = 1; - deprecated_num_roots = 1; + static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int), "TreeParam: 64 bit align"); + std::memset(reserved, 0, sizeof(reserved)); } // Swap byte order for all fields. Useful for transporting models between machines with different // endianness (big endian vs little endian) - inline TreeParam ByteSwap() const { + [[nodiscard]] TreeParam ByteSwap() const { TreeParam x = *this; dmlc::ByteSwap(&x.deprecated_num_roots, sizeof(x.deprecated_num_roots), 1); dmlc::ByteSwap(&x.num_nodes, sizeof(x.num_nodes), 1); @@ -80,17 +77,18 @@ struct TreeParam : public dmlc::Parameter { // other arguments are set by the algorithm. DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1); DMLC_DECLARE_FIELD(num_feature) + .set_default(0) .describe("Number of features used in tree construction."); - DMLC_DECLARE_FIELD(num_deleted); - DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0) + DMLC_DECLARE_FIELD(num_deleted).set_default(0); + DMLC_DECLARE_FIELD(size_leaf_vector) + .set_lower_bound(0) + .set_default(1) .describe("Size of leaf vector, reserved for vector tree"); } bool operator==(const TreeParam& b) const { - return num_nodes == b.num_nodes && - num_deleted == b.num_deleted && - num_feature == b.num_feature && - size_leaf_vector == b.size_leaf_vector; + return num_nodes == b.num_nodes && num_deleted == b.num_deleted && + num_feature == b.num_feature && size_leaf_vector == b.size_leaf_vector; } }; @@ -114,7 +112,7 @@ struct RTreeNodeStat { } // Swap byte order for all fields. Useful for transporting models between machines with different // endianness (big endian vs little endian) - inline RTreeNodeStat ByteSwap() const { + [[nodiscard]] RTreeNodeStat ByteSwap() const { RTreeNodeStat x = *this; dmlc::ByteSwap(&x.loss_chg, sizeof(x.loss_chg), 1); dmlc::ByteSwap(&x.sum_hess, sizeof(x.sum_hess), 1); @@ -124,16 +122,45 @@ struct RTreeNodeStat { } }; -/*! +/** + * \brief Helper for defining copyable data structure that contains unique pointers. + */ +template +class CopyUniquePtr { + std::unique_ptr ptr_{nullptr}; + + public: + CopyUniquePtr() = default; + CopyUniquePtr(CopyUniquePtr const& that) { + ptr_.reset(nullptr); + if (that.ptr_) { + ptr_ = std::make_unique(*that); + } + } + T* get() const noexcept { return ptr_.get(); } // NOLINT + + T& operator*() { return *ptr_; } + T* operator->() noexcept { return this->get(); } + + T const& operator*() const { return *ptr_; } + T const* operator->() const noexcept { return this->get(); } + + explicit operator bool() const { return static_cast(ptr_); } + bool operator!() const { return !ptr_; } + void reset(T* ptr) { ptr_.reset(ptr); } // NOLINT +}; + +/** * \brief define regression tree to be the most common tree model. + * * This is the data structure used in xgboost's major tree models. */ class RegTree : public Model { public: using SplitCondT = bst_float; - static constexpr bst_node_t kInvalidNodeId {-1}; + static constexpr bst_node_t kInvalidNodeId{MultiTargetTree::InvalidNodeId()}; static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits::max(); - static constexpr bst_node_t kRoot { 0 }; + static constexpr bst_node_t kRoot{0}; /*! \brief tree node */ class Node { @@ -151,51 +178,51 @@ class RegTree : public Model { } /*! \brief index of left child */ - XGBOOST_DEVICE int LeftChild() const { + XGBOOST_DEVICE [[nodiscard]] int LeftChild() const { return this->cleft_; } /*! \brief index of right child */ - XGBOOST_DEVICE int RightChild() const { + XGBOOST_DEVICE [[nodiscard]] int RightChild() const { return this->cright_; } /*! \brief index of default child when feature is missing */ - XGBOOST_DEVICE int DefaultChild() const { + XGBOOST_DEVICE [[nodiscard]] int DefaultChild() const { return this->DefaultLeft() ? this->LeftChild() : this->RightChild(); } /*! \brief feature index of split condition */ - XGBOOST_DEVICE unsigned SplitIndex() const { + XGBOOST_DEVICE [[nodiscard]] unsigned SplitIndex() const { return sindex_ & ((1U << 31) - 1U); } /*! \brief when feature is unknown, whether goes to left child */ - XGBOOST_DEVICE bool DefaultLeft() const { + XGBOOST_DEVICE [[nodiscard]] bool DefaultLeft() const { return (sindex_ >> 31) != 0; } /*! \brief whether current node is leaf node */ - XGBOOST_DEVICE bool IsLeaf() const { + XGBOOST_DEVICE [[nodiscard]] bool IsLeaf() const { return cleft_ == kInvalidNodeId; } /*! \return get leaf value of leaf node */ - XGBOOST_DEVICE bst_float LeafValue() const { + XGBOOST_DEVICE [[nodiscard]] float LeafValue() const { return (this->info_).leaf_value; } /*! \return get split condition of the node */ - XGBOOST_DEVICE SplitCondT SplitCond() const { + XGBOOST_DEVICE [[nodiscard]] SplitCondT SplitCond() const { return (this->info_).split_cond; } /*! \brief get parent of the node */ - XGBOOST_DEVICE int Parent() const { + XGBOOST_DEVICE [[nodiscard]] int Parent() const { return parent_ & ((1U << 31) - 1); } /*! \brief whether current node is left child */ - XGBOOST_DEVICE bool IsLeftChild() const { + XGBOOST_DEVICE [[nodiscard]] bool IsLeftChild() const { return (parent_ & (1U << 31)) != 0; } /*! \brief whether this node is deleted */ - XGBOOST_DEVICE bool IsDeleted() const { + XGBOOST_DEVICE [[nodiscard]] bool IsDeleted() const { return sindex_ == kDeletedNodeMarker; } /*! \brief whether current node is root */ - XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; } + XGBOOST_DEVICE [[nodiscard]] bool IsRoot() const { return parent_ == kInvalidNodeId; } /*! * \brief set the left child * \param nid node id to right child @@ -252,7 +279,7 @@ class RegTree : public Model { info_.leaf_value == b.info_.leaf_value; } - inline Node ByteSwap() const { + [[nodiscard]] Node ByteSwap() const { Node x = *this; dmlc::ByteSwap(&x.parent_, sizeof(x.parent_), 1); dmlc::ByteSwap(&x.cleft_, sizeof(x.cleft_), 1); @@ -312,19 +339,28 @@ class RegTree : public Model { /*! \brief model parameter */ TreeParam param; - /*! \brief constructor */ RegTree() { - param.num_nodes = 1; - param.num_deleted = 0; + param.Init(Args{}); nodes_.resize(param.num_nodes); stats_.resize(param.num_nodes); split_types_.resize(param.num_nodes, FeatureType::kNumerical); split_categories_segments_.resize(param.num_nodes); - for (int i = 0; i < param.num_nodes; i ++) { + for (int i = 0; i < param.num_nodes; i++) { nodes_[i].SetLeaf(0.0f); nodes_[i].SetParent(kInvalidNodeId); } } + /** + * \brief Constructor that initializes the tree model with shape. + */ + explicit RegTree(bst_target_t n_targets, bst_feature_t n_features) : RegTree{} { + param.num_feature = n_features; + param.size_leaf_vector = n_targets; + if (n_targets > 1) { + this->p_mt_tree_.reset(new MultiTargetTree{¶m}); + } + } + /*! \brief get node given nid */ Node& operator[](int nid) { return nodes_[nid]; @@ -335,17 +371,17 @@ class RegTree : public Model { } /*! \brief get const reference to nodes */ - const std::vector& GetNodes() const { return nodes_; } + [[nodiscard]] const std::vector& GetNodes() const { return nodes_; } /*! \brief get const reference to stats */ - const std::vector& GetStats() const { return stats_; } + [[nodiscard]] const std::vector& GetStats() const { return stats_; } /*! \brief get node statistics given nid */ RTreeNodeStat& Stat(int nid) { return stats_[nid]; } /*! \brief get node statistics given nid */ - const RTreeNodeStat& Stat(int nid) const { + [[nodiscard]] const RTreeNodeStat& Stat(int nid) const { return stats_[nid]; } @@ -398,7 +434,7 @@ class RegTree : public Model { * * \param b The other tree. */ - bool Equal(const RegTree& b) const; + [[nodiscard]] bool Equal(const RegTree& b) const; /** * \brief Expands a leaf node into two additional leaf nodes. @@ -424,6 +460,11 @@ class RegTree : public Model { float right_sum, bst_node_t leaf_right_child = kInvalidNodeId); + void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left, + linalg::VectorView base_weight, + linalg::VectorView left_weight, + linalg::VectorView right_weight); + /** * \brief Expands a leaf node with categories * @@ -445,15 +486,27 @@ class RegTree : public Model { bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum); - bool HasCategoricalSplit() const { + [[nodiscard]] bool HasCategoricalSplit() const { return !split_categories_.empty(); } + /** + * \brief Whether this is a multi-target tree. + */ + [[nodiscard]] bool IsMultiTarget() const { return static_cast(p_mt_tree_); } + [[nodiscard]] bst_target_t NumTargets() const { return param.size_leaf_vector; } + [[nodiscard]] auto GetMultiTargetTree() const { + CHECK(IsMultiTarget()); + return p_mt_tree_.get(); + } /*! * \brief get current depth * \param nid node id */ - int GetDepth(int nid) const { + [[nodiscard]] std::int32_t GetDepth(bst_node_t nid) const { + if (IsMultiTarget()) { + return this->p_mt_tree_->Depth(nid); + } int depth = 0; while (!nodes_[nid].IsRoot()) { ++depth; @@ -461,12 +514,16 @@ class RegTree : public Model { } return depth; } + void SetLeaf(bst_node_t nidx, linalg::VectorView weight) { + CHECK(IsMultiTarget()); + return this->p_mt_tree_->SetLeaf(nidx, weight); + } /*! * \brief get maximum depth * \param nid node id */ - int MaxDepth(int nid) const { + [[nodiscard]] int MaxDepth(int nid) const { if (nodes_[nid].IsLeaf()) return 0; return std::max(MaxDepth(nodes_[nid].LeftChild())+1, MaxDepth(nodes_[nid].RightChild())+1); @@ -480,13 +537,13 @@ class RegTree : public Model { } /*! \brief number of extra nodes besides the root */ - int NumExtraNodes() const { + [[nodiscard]] int NumExtraNodes() const { return param.num_nodes - 1 - param.num_deleted; } /* \brief Count number of leaves in tree. */ - bst_node_t GetNumLeaves() const; - bst_node_t GetNumSplitNodes() const; + [[nodiscard]] bst_node_t GetNumLeaves() const; + [[nodiscard]] bst_node_t GetNumSplitNodes() const; /*! * \brief dense feature vector that can be taken by RegTree @@ -513,20 +570,20 @@ class RegTree : public Model { * \brief returns the size of the feature vector * \return the size of the feature vector */ - size_t Size() const; + [[nodiscard]] size_t Size() const; /*! * \brief get ith value * \param i feature index. * \return the i-th feature value */ - bst_float GetFvalue(size_t i) const; + [[nodiscard]] bst_float GetFvalue(size_t i) const; /*! * \brief check whether i-th entry is missing * \param i feature index. * \return whether i-th value is missing. */ - bool IsMissing(size_t i) const; - bool HasMissing() const; + [[nodiscard]] bool IsMissing(size_t i) const; + [[nodiscard]] bool HasMissing() const; private: @@ -557,56 +614,123 @@ class RegTree : public Model { * \param format the format to dump the model in * \return the string of dumped model */ - std::string DumpModel(const FeatureMap& fmap, - bool with_stats, - std::string format) const; + [[nodiscard]] std::string DumpModel(const FeatureMap& fmap, bool with_stats, + std::string format) const; /*! * \brief Get split type for a node. * \param nidx Index of node. * \return The type of this split. For leaf node it's always kNumerical. */ - FeatureType NodeSplitType(bst_node_t nidx) const { - return split_types_.at(nidx); - } + [[nodiscard]] FeatureType NodeSplitType(bst_node_t nidx) const { return split_types_.at(nidx); } /*! * \brief Get split types for all nodes. */ - std::vector const &GetSplitTypes() const { return split_types_; } - common::Span GetSplitCategories() const { return split_categories_; } + [[nodiscard]] std::vector const& GetSplitTypes() const { + return split_types_; + } + [[nodiscard]] common::Span GetSplitCategories() const { + return split_categories_; + } /*! * \brief Get the bit storage for categories */ - common::Span NodeCats(bst_node_t nidx) const { + [[nodiscard]] common::Span NodeCats(bst_node_t nidx) const { auto node_ptr = GetCategoriesMatrix().node_ptr; auto categories = GetCategoriesMatrix().categories; auto segment = node_ptr[nidx]; auto node_cats = categories.subspan(segment.beg, segment.size); return node_cats; } - auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; } - - // The fields of split_categories_segments_[i] are set such that - // the range split_categories_[beg:(beg+size)] stores the bitset for - // the matching categories for the i-th node. - struct Segment { - size_t beg {0}; - size_t size {0}; - }; + [[nodiscard]] auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; } + /** + * \brief CSR-like matrix for categorical splits. + * + * The fields of split_categories_segments_[i] are set such that the range + * node_ptr[beg:(beg+size)] stores the bitset for the matching categories for the + * i-th node. + */ struct CategoricalSplitMatrix { + struct Segment { + std::size_t beg{0}; + std::size_t size{0}; + }; common::Span split_type; common::Span categories; common::Span node_ptr; }; - CategoricalSplitMatrix GetCategoriesMatrix() const { + [[nodiscard]] CategoricalSplitMatrix GetCategoriesMatrix() const { CategoricalSplitMatrix view; view.split_type = common::Span(this->GetSplitTypes()); view.categories = this->GetSplitCategories(); - view.node_ptr = common::Span(split_categories_segments_); + view.node_ptr = common::Span(split_categories_segments_); return view; } + [[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const { + if (IsMultiTarget()) { + return this->p_mt_tree_->SplitIndex(nidx); + } + return (*this)[nidx].SplitIndex(); + } + [[nodiscard]] float SplitCond(bst_node_t nidx) const { + if (IsMultiTarget()) { + return this->p_mt_tree_->SplitCond(nidx); + } + return (*this)[nidx].SplitCond(); + } + [[nodiscard]] bool DefaultLeft(bst_node_t nidx) const { + if (IsMultiTarget()) { + return this->p_mt_tree_->DefaultLeft(nidx); + } + return (*this)[nidx].DefaultLeft(); + } + [[nodiscard]] bool IsRoot(bst_node_t nidx) const { + if (IsMultiTarget()) { + return nidx == kRoot; + } + return (*this)[nidx].IsRoot(); + } + [[nodiscard]] bool IsLeaf(bst_node_t nidx) const { + if (IsMultiTarget()) { + return this->p_mt_tree_->IsLeaf(nidx); + } + return (*this)[nidx].IsLeaf(); + } + [[nodiscard]] bst_node_t Parent(bst_node_t nidx) const { + if (IsMultiTarget()) { + return this->p_mt_tree_->Parent(nidx); + } + return (*this)[nidx].Parent(); + } + [[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const { + if (IsMultiTarget()) { + return this->p_mt_tree_->LeftChild(nidx); + } + return (*this)[nidx].LeftChild(); + } + [[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const { + if (IsMultiTarget()) { + return this->p_mt_tree_->RightChild(nidx); + } + return (*this)[nidx].RightChild(); + } + [[nodiscard]] bool IsLeftChild(bst_node_t nidx) const { + if (IsMultiTarget()) { + CHECK_NE(nidx, kRoot); + auto p = this->p_mt_tree_->Parent(nidx); + return nidx == this->p_mt_tree_->LeftChild(p); + } + return (*this)[nidx].IsLeftChild(); + } + [[nodiscard]] bst_node_t Size() const { + if (IsMultiTarget()) { + return this->p_mt_tree_->Size(); + } + return this->nodes_.size(); + } + private: template void LoadCategoricalSplit(Json const& in); @@ -622,8 +746,9 @@ class RegTree : public Model { // Categories for each internal node. std::vector split_categories_; // Ptr to split categories of each node. - std::vector split_categories_segments_; - + std::vector split_categories_segments_; + // ptr to multi-target tree with vector leaf. + CopyUniquePtr p_mt_tree_; // allocate a new node, // !!!!!! NOTE: may cause BUG here, nodes.resize bst_node_t AllocNode() { @@ -703,5 +828,10 @@ inline bool RegTree::FVec::IsMissing(size_t i) const { inline bool RegTree::FVec::HasMissing() const { return has_missing_; } + +// Multi-target tree not yet implemented error +inline StringView MTNotImplemented() { + return " support for multi-target tree is not yet implemented."; +} } // namespace xgboost #endif // XGBOOST_TREE_MODEL_H_ diff --git a/src/common/host_device_vector.cc b/src/common/host_device_vector.cc index 030070d9aecd..55c0ecf202f8 100644 --- a/src/common/host_device_vector.cc +++ b/src/common/host_device_vector.cc @@ -1,5 +1,5 @@ -/*! - * Copyright 2017 XGBoost contributors +/** + * Copyright 2017-2023 by XGBoost contributors */ #ifndef XGBOOST_USE_CUDA @@ -179,7 +179,6 @@ template class HostDeviceVector; template class HostDeviceVector; template class HostDeviceVector; // bst_row_t template class HostDeviceVector; // bst_feature_t -template class HostDeviceVector; #if defined(__APPLE__) || defined(__EMSCRIPTEN__) /* diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu index a5c5dbf8fa1b..1fa9a3b2200c 100644 --- a/src/common/host_device_vector.cu +++ b/src/common/host_device_vector.cu @@ -1,7 +1,6 @@ -/*! - * Copyright 2017 XGBoost contributors +/** + * Copyright 2017-2023 by XGBoost contributors */ - #include #include @@ -412,7 +411,7 @@ template class HostDeviceVector; template class HostDeviceVector; // bst_row_t template class HostDeviceVector; // bst_feature_t template class HostDeviceVector; -template class HostDeviceVector; +template class HostDeviceVector; template class HostDeviceVector; #if defined(__APPLE__) diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 35daf701c9d3..caf4b6bb4ea0 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -1,5 +1,5 @@ -/*! - * Copyright 2017-2021 by Contributors +/** + * Copyright 2017-2023 by XGBoost Contributors */ #include #include @@ -25,9 +25,7 @@ #include "xgboost/tree_model.h" #include "xgboost/tree_updater.h" -namespace xgboost { -namespace predictor { - +namespace xgboost::predictor { DMLC_REGISTRY_FILE_TAG(gpu_predictor); struct TreeView { @@ -35,12 +33,11 @@ struct TreeView { common::Span d_tree; XGBOOST_DEVICE - TreeView(size_t tree_begin, size_t tree_idx, - common::Span d_nodes, + TreeView(size_t tree_begin, size_t tree_idx, common::Span d_nodes, common::Span d_tree_segments, common::Span d_tree_split_types, common::Span d_cat_tree_segments, - common::Span d_cat_node_segments, + common::Span d_cat_node_segments, common::Span d_categories) { auto begin = d_tree_segments[tree_idx - tree_begin]; auto n_nodes = d_tree_segments[tree_idx - tree_begin + 1] - @@ -255,7 +252,7 @@ PredictLeafKernel(Data data, common::Span d_nodes, common::Span d_tree_split_types, common::Span d_cat_tree_segments, - common::Span d_cat_node_segments, + common::Span d_cat_node_segments, common::Span d_categories, size_t tree_begin, size_t tree_end, size_t num_features, @@ -290,7 +287,7 @@ PredictKernel(Data data, common::Span d_nodes, common::Span d_tree_group, common::Span d_tree_split_types, common::Span d_cat_tree_segments, - common::Span d_cat_node_segments, + common::Span d_cat_node_segments, common::Span d_categories, size_t tree_begin, size_t tree_end, size_t num_features, size_t num_rows, size_t entry_start, bool use_shared, int num_group, float missing) { @@ -334,7 +331,7 @@ class DeviceModel { // Pointer to each tree, segmenting the node array. HostDeviceVector categories_tree_segments; // Pointer to each node, segmenting categories array. - HostDeviceVector categories_node_segments; + HostDeviceVector categories_node_segments; HostDeviceVector categories; size_t tree_beg_; // NOLINT @@ -400,9 +397,9 @@ class DeviceModel { h_split_cat_segments.push_back(h_categories.size()); } - categories_node_segments = - HostDeviceVector(h_tree_segments.back(), {}, gpu_id); - std::vector &h_categories_node_segments = + categories_node_segments = HostDeviceVector( + h_tree_segments.back(), {}, gpu_id); + std::vector& h_categories_node_segments = categories_node_segments.HostVector(); for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { auto const &src_cats_ptr = model.trees.at(tree_idx)->GetSplitCategoriesPtr(); @@ -542,10 +539,10 @@ void ExtractPaths( if (thrust::any_of(dh::tbegin(d_split_types), dh::tend(d_split_types), common::IsCatOp{})) { dh::PinnedMemory pinned; - auto h_max_cat = pinned.GetSpan(1); + auto h_max_cat = pinned.GetSpan(1); auto max_elem_it = dh::MakeTransformIterator( dh::tbegin(d_cat_node_segments), - [] __device__(RegTree::Segment seg) { return seg.size; }); + [] __device__(RegTree::CategoricalSplitMatrix::Segment seg) { return seg.size; }); size_t max_cat_it = thrust::max_element(thrust::device, max_elem_it, max_elem_it + d_cat_node_segments.size()) - @@ -1028,5 +1025,4 @@ XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor") .describe("Make predictions using GPU.") .set_body([](Context const* ctx) { return new GPUPredictor(ctx); }); -} // namespace predictor -} // namespace xgboost +} // namespace xgboost::predictor diff --git a/src/tree/fit_stump.cc b/src/tree/fit_stump.cc index 82efff2c77ac..ad0253d22be4 100644 --- a/src/tree/fit_stump.cc +++ b/src/tree/fit_stump.cc @@ -71,10 +71,7 @@ void FitStump(Context const* ctx, HostDeviceVector const& gpair, auto n_samples = gpair.Size() / n_targets; gpair.SetDevice(ctx->gpu_id); - linalg::TensorView gpair_t{ - ctx->IsCPU() ? gpair.ConstHostSpan() : gpair.ConstDeviceSpan(), - {n_samples, n_targets}, - ctx->gpu_id}; + auto gpair_t = linalg::MakeTensorView(ctx, &gpair, n_samples, n_targets); ctx->IsCPU() ? cpu_impl::FitStump(ctx, gpair_t, out->HostView()) : cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id)); } diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index 4e64cbd7533b..50b90f244aad 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -12,7 +12,7 @@ #include "../../common/hist_util.h" #include "../../data/gradient_index.h" #include "expand_entry.h" -#include "xgboost/tree_model.h" +#include "xgboost/tree_model.h" // for RegTree namespace xgboost { namespace tree { @@ -175,8 +175,8 @@ class HistogramBuilder { auto this_local = hist_local_worker_[entry.nid]; common::CopyHist(this_local, this_hist, r.begin(), r.end()); - if (!(*p_tree)[entry.nid].IsRoot()) { - const size_t parent_id = (*p_tree)[entry.nid].Parent(); + if (!p_tree->IsRoot(entry.nid)) { + const size_t parent_id = p_tree->Parent(entry.nid); const int subtraction_node_id = nodes_for_subtraction_trick[node].nid; auto parent_hist = this->hist_local_worker_[parent_id]; auto sibling_hist = this->hist_[subtraction_node_id]; @@ -213,8 +213,8 @@ class HistogramBuilder { // Merging histograms from each thread into once this->buffer_.ReduceHist(node, r.begin(), r.end()); - if (!(*p_tree)[entry.nid].IsRoot()) { - auto const parent_id = (*p_tree)[entry.nid].Parent(); + if (!p_tree->IsRoot(entry.nid)) { + auto const parent_id = p_tree->Parent(entry.nid); auto const subtraction_node_id = nodes_for_subtraction_trick[node].nid; auto parent_hist = this->hist_[parent_id]; auto sibling_hist = this->hist_[subtraction_node_id]; @@ -237,10 +237,10 @@ class HistogramBuilder { common::ParallelFor2d( space, this->n_threads_, [&](size_t node, common::Range1d r) { const auto &entry = nodes[node]; - if (!((*p_tree)[entry.nid].IsLeftChild())) { + if (!(p_tree->IsLeftChild(entry.nid))) { auto this_hist = this->hist_[entry.nid]; - if (!(*p_tree)[entry.nid].IsRoot()) { + if (!p_tree->IsRoot(entry.nid)) { const int subtraction_node_id = subtraction_nodes[node].nid; auto parent_hist = hist_[(*p_tree)[entry.nid].Parent()]; auto sibling_hist = hist_[subtraction_node_id]; @@ -285,7 +285,7 @@ class HistogramBuilder { std::sort(merged_node_ids.begin(), merged_node_ids.end()); int n_left = 0; for (auto const &nid : merged_node_ids) { - if ((*p_tree)[nid].IsLeftChild()) { + if (p_tree->IsLeftChild(nid)) { this->hist_.AddHistRow(nid); (*starting_index) = std::min(nid, (*starting_index)); n_left++; @@ -293,7 +293,7 @@ class HistogramBuilder { } } for (auto const &nid : merged_node_ids) { - if (!((*p_tree)[nid].IsLeftChild())) { + if (!(p_tree->IsLeftChild(nid))) { this->hist_.AddHistRow(nid); this->hist_local_worker_.AddHistRow(nid); } diff --git a/src/tree/io_utils.h b/src/tree/io_utils.h new file mode 100644 index 000000000000..a0d31cc83bd3 --- /dev/null +++ b/src/tree/io_utils.h @@ -0,0 +1,65 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#ifndef XGBOOST_TREE_IO_UTILS_H_ +#define XGBOOST_TREE_IO_UTILS_H_ +#include // for string +#include // for enable_if_t, is_same, conditional_t +#include // for vector + +#include "xgboost/json.h" // for Json + +namespace xgboost { +template +using FloatArrayT = std::conditional_t; +template +using U8ArrayT = std::conditional_t; +template +using I32ArrayT = std::conditional_t; +template +using I64ArrayT = std::conditional_t; +template +using IndexArrayT = std::conditional_t, I32ArrayT>; + +// typed array, not boolean +template +std::enable_if_t::value && !std::is_same::value, T> GetElem( + std::vector const& arr, size_t i) { + return arr[i]; +} +// typed array boolean +template +std::enable_if_t::value && std::is_same::value && + std::is_same::value, + bool> +GetElem(std::vector const& arr, size_t i) { + return arr[i] == 1; +} +// json array +template +std::enable_if_t< + std::is_same::value, + std::conditional_t::value, int64_t, + std::conditional_t::value, bool, float>>> +GetElem(std::vector const& arr, size_t i) { + if (std::is_same::value && !IsA(arr[i])) { + return get(arr[i]) == 1; + } + return get(arr[i]); +} + +namespace tree_field { +inline std::string const kLossChg{"loss_changes"}; +inline std::string const kSumHess{"sum_hessian"}; +inline std::string const kBaseWeight{"base_weights"}; + +inline std::string const kSplitIdx{"split_indices"}; +inline std::string const kSplitCond{"split_conditions"}; +inline std::string const kDftLeft{"default_left"}; + +inline std::string const kParent{"parents"}; +inline std::string const kLeft{"left_children"}; +inline std::string const kRight{"right_children"}; +} // namespace tree_field +} // namespace xgboost +#endif // XGBOOST_TREE_IO_UTILS_H_ diff --git a/src/tree/multi_target_tree_model.cc b/src/tree/multi_target_tree_model.cc new file mode 100644 index 000000000000..bccc1967e9cc --- /dev/null +++ b/src/tree/multi_target_tree_model.cc @@ -0,0 +1,220 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#include "xgboost/multi_target_tree_model.h" + +#include // for copy_n +#include // for size_t +#include // for int32_t, uint8_t +#include // for numeric_limits +#include // for string_view +#include // for move +#include // for vector + +#include "io_utils.h" // for I32ArrayT, FloatArrayT, GetElem, ... +#include "xgboost/base.h" // for bst_node_t, bst_feature_t, bst_target_t +#include "xgboost/json.h" // for Json, get, Object, Number, Integer, ... +#include "xgboost/logging.h" +#include "xgboost/tree_model.h" // for TreeParam + +namespace xgboost { +MultiTargetTree::MultiTargetTree(TreeParam const* param) + : param_{param}, + left_(1ul, InvalidNodeId()), + right_(1ul, InvalidNodeId()), + parent_(1ul, InvalidNodeId()), + split_index_(1ul, 0), + default_left_(1ul, 0), + split_conds_(1ul, std::numeric_limits::quiet_NaN()), + weights_(param->size_leaf_vector, std::numeric_limits::quiet_NaN()) { + CHECK_GT(param_->size_leaf_vector, 1); +} + +template +void LoadModelImpl(Json const& in, std::vector* p_weights, std::vector* p_lefts, + std::vector* p_rights, std::vector* p_parents, + std::vector* p_conds, std::vector* p_fidx, + std::vector* p_dft_left) { + namespace tf = tree_field; + + auto get_float = [&](std::string_view name, std::vector* p_out) { + auto& values = get>(get(in).find(name)->second); + auto& out = *p_out; + out.resize(values.size()); + for (std::size_t i = 0; i < values.size(); ++i) { + out[i] = GetElem(values, i); + } + }; + get_float(tf::kBaseWeight, p_weights); + get_float(tf::kSplitCond, p_conds); + + auto get_nidx = [&](std::string_view name, std::vector* p_nidx) { + auto& nidx = get>(get(in).find(name)->second); + auto& out_nidx = *p_nidx; + out_nidx.resize(nidx.size()); + for (std::size_t i = 0; i < nidx.size(); ++i) { + out_nidx[i] = GetElem(nidx, i); + } + }; + get_nidx(tf::kLeft, p_lefts); + get_nidx(tf::kRight, p_rights); + get_nidx(tf::kParent, p_parents); + + auto const& splits = get const>(in[tf::kSplitIdx]); + p_fidx->resize(splits.size()); + auto& out_fidx = *p_fidx; + for (std::size_t i = 0; i < splits.size(); ++i) { + out_fidx[i] = GetElem(splits, i); + } + + auto const& dft_left = get const>(in[tf::kDftLeft]); + auto& out_dft_l = *p_dft_left; + out_dft_l.resize(dft_left.size()); + for (std::size_t i = 0; i < dft_left.size(); ++i) { + out_dft_l[i] = GetElem(dft_left, i); + } +} + +void MultiTargetTree::LoadModel(Json const& in) { + namespace tf = tree_field; + bool typed = IsA(in[tf::kBaseWeight]); + bool feature_is_64 = IsA(in[tf::kSplitIdx]); + + if (typed && feature_is_64) { + LoadModelImpl(in, &weights_, &left_, &right_, &parent_, &split_conds_, + &split_index_, &default_left_); + } else if (typed && !feature_is_64) { + LoadModelImpl(in, &weights_, &left_, &right_, &parent_, &split_conds_, + &split_index_, &default_left_); + } else if (!typed && feature_is_64) { + LoadModelImpl(in, &weights_, &left_, &right_, &parent_, &split_conds_, + &split_index_, &default_left_); + } else { + LoadModelImpl(in, &weights_, &left_, &right_, &parent_, &split_conds_, + &split_index_, &default_left_); + } +} + +void MultiTargetTree::SaveModel(Json* p_out) const { + CHECK(p_out); + auto& out = *p_out; + + auto n_nodes = param_->num_nodes; + + // nodes + I32Array lefts(n_nodes); + I32Array rights(n_nodes); + I32Array parents(n_nodes); + F32Array conds(n_nodes); + U8Array default_left(n_nodes); + F32Array weights(n_nodes * this->NumTarget()); + + auto save_tree = [&](auto* p_indices_array) { + auto& indices_array = *p_indices_array; + for (bst_node_t nidx = 0; nidx < n_nodes; ++nidx) { + CHECK_LT(nidx, left_.size()); + lefts.Set(nidx, left_[nidx]); + CHECK_LT(nidx, right_.size()); + rights.Set(nidx, right_[nidx]); + CHECK_LT(nidx, parent_.size()); + parents.Set(nidx, parent_[nidx]); + CHECK_LT(nidx, split_index_.size()); + indices_array.Set(nidx, split_index_[nidx]); + conds.Set(nidx, split_conds_[nidx]); + default_left.Set(nidx, default_left_[nidx]); + + auto in_weight = this->NodeWeight(nidx); + auto weight_out = common::Span(weights.GetArray()) + .subspan(nidx * this->NumTarget(), this->NumTarget()); + CHECK_EQ(in_weight.Size(), weight_out.size()); + std::copy_n(in_weight.Values().data(), in_weight.Size(), weight_out.data()); + } + }; + + namespace tf = tree_field; + + if (this->param_->num_feature > + static_cast(std::numeric_limits::max())) { + I64Array indices_64(n_nodes); + save_tree(&indices_64); + out[tf::kSplitIdx] = std::move(indices_64); + } else { + I32Array indices_32(n_nodes); + save_tree(&indices_32); + out[tf::kSplitIdx] = std::move(indices_32); + } + + out[tf::kBaseWeight] = std::move(weights); + out[tf::kLeft] = std::move(lefts); + out[tf::kRight] = std::move(rights); + out[tf::kParent] = std::move(parents); + + out[tf::kSplitCond] = std::move(conds); + out[tf::kDftLeft] = std::move(default_left); +} + +void MultiTargetTree::SetLeaf(bst_node_t nidx, linalg::VectorView weight) { + CHECK(this->IsLeaf(nidx)) << "Collapsing a split node to leaf " << MTNotImplemented(); + auto const next_nidx = nidx + 1; + CHECK_EQ(weight.Size(), this->NumTarget()); + CHECK_GE(weights_.size(), next_nidx * weight.Size()); + auto out_weight = common::Span(weights_).subspan(nidx * weight.Size(), weight.Size()); + for (std::size_t i = 0; i < weight.Size(); ++i) { + out_weight[i] = weight(i); + } +} + +void MultiTargetTree::Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, + bool default_left, linalg::VectorView base_weight, + linalg::VectorView left_weight, + linalg::VectorView right_weight) { + CHECK(this->IsLeaf(nidx)); + CHECK_GE(parent_.size(), 1); + CHECK_EQ(parent_.size(), left_.size()); + CHECK_EQ(left_.size(), right_.size()); + + std::size_t n = param_->num_nodes + 2; + CHECK_LT(split_idx, this->param_->num_feature); + left_.resize(n, InvalidNodeId()); + right_.resize(n, InvalidNodeId()); + parent_.resize(n, InvalidNodeId()); + + auto left_child = parent_.size() - 2; + auto right_child = parent_.size() - 1; + + left_[nidx] = left_child; + right_[nidx] = right_child; + + if (nidx != 0) { + CHECK_NE(parent_[nidx], InvalidNodeId()); + } + + parent_[left_child] = nidx; + parent_[right_child] = nidx; + + split_index_.resize(n); + split_index_[nidx] = split_idx; + + split_conds_.resize(n); + split_conds_[nidx] = split_cond; + default_left_.resize(n); + default_left_[nidx] = static_cast(default_left); + + weights_.resize(n * this->NumTarget()); + auto p_weight = this->NodeWeight(nidx); + CHECK_EQ(p_weight.Size(), base_weight.Size()); + auto l_weight = this->NodeWeight(left_child); + CHECK_EQ(l_weight.Size(), left_weight.Size()); + auto r_weight = this->NodeWeight(right_child); + CHECK_EQ(r_weight.Size(), right_weight.Size()); + + for (std::size_t i = 0; i < base_weight.Size(); ++i) { + p_weight(i) = base_weight(i); + l_weight(i) = left_weight(i); + r_weight(i) = right_weight(i); + } +} + +bst_target_t MultiTargetTree::NumTarget() const { return param_->size_leaf_vector; } +std::size_t MultiTargetTree::Size() const { return parent_.size(); } +} // namespace xgboost diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 55e37a9190d1..0891ec3b2aae 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -1,25 +1,27 @@ -/*! - * Copyright 2015-2022 by Contributors +/** + * Copyright 2015-2023 by Contributors * \file tree_model.cc * \brief model structure for tree */ -#include #include - -#include -#include +#include #include +#include -#include -#include #include #include -#include +#include +#include +#include -#include "param.h" -#include "../common/common.h" #include "../common/categorical.h" +#include "../common/common.h" #include "../predictor/predict_fn.h" +#include "io_utils.h" // GetElem +#include "param.h" +#include "xgboost/base.h" +#include "xgboost/data.h" +#include "xgboost/logging.h" namespace xgboost { // register tree parameter @@ -729,12 +731,9 @@ XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot") constexpr bst_node_t RegTree::kRoot; -std::string RegTree::DumpModel(const FeatureMap& fmap, - bool with_stats, - std::string format) const { - std::unique_ptr builder { - TreeGenerator::Create(format, fmap, with_stats) - }; +std::string RegTree::DumpModel(const FeatureMap& fmap, bool with_stats, std::string format) const { + CHECK(!IsMultiTarget()); + std::unique_ptr builder{TreeGenerator::Create(format, fmap, with_stats)}; builder->BuildTree(*this); std::string result = builder->Str(); @@ -742,6 +741,7 @@ std::string RegTree::DumpModel(const FeatureMap& fmap, } bool RegTree::Equal(const RegTree& b) const { + CHECK(!IsMultiTarget()); if (NumExtraNodes() != b.NumExtraNodes()) { return false; } @@ -758,6 +758,7 @@ bool RegTree::Equal(const RegTree& b) const { } bst_node_t RegTree::GetNumLeaves() const { + CHECK(!IsMultiTarget()); bst_node_t leaves { 0 }; auto const& self = *this; this->WalkTree([&leaves, &self](bst_node_t nidx) { @@ -770,6 +771,7 @@ bst_node_t RegTree::GetNumLeaves() const { } bst_node_t RegTree::GetNumSplitNodes() const { + CHECK(!IsMultiTarget()); bst_node_t splits { 0 }; auto const& self = *this; this->WalkTree([&splits, &self](bst_node_t nidx) { @@ -787,6 +789,7 @@ void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_v bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child) { + CHECK(!IsMultiTarget()); int pleft = this->AllocNode(); int pright = this->AllocNode(); auto &node = nodes_[nid]; @@ -807,11 +810,31 @@ void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_v this->split_types_.at(nid) = FeatureType::kNumerical; } +void RegTree::ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, + bool default_left, linalg::VectorView base_weight, + linalg::VectorView left_weight, + linalg::VectorView right_weight) { + CHECK(IsMultiTarget()); + CHECK_LT(split_index, this->param.num_feature); + CHECK(this->p_mt_tree_); + CHECK_GT(param.size_leaf_vector, 1); + + this->p_mt_tree_->Expand(nidx, split_index, split_cond, default_left, base_weight, left_weight, + right_weight); + + split_types_.resize(this->Size(), FeatureType::kNumerical); + split_categories_segments_.resize(this->Size()); + this->split_types_.at(nidx) = FeatureType::kNumerical; + + this->param.num_nodes = this->p_mt_tree_->Size(); +} + void RegTree::ExpandCategorical(bst_node_t nid, bst_feature_t split_index, common::Span split_cat, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum) { + CHECK(!IsMultiTarget()); this->ExpandNode(nid, split_index, std::numeric_limits::quiet_NaN(), default_left, base_weight, left_leaf_weight, right_leaf_weight, loss_change, sum_hess, @@ -893,44 +916,17 @@ void RegTree::Save(dmlc::Stream* fo) const { } } } -// typed array, not boolean -template -std::enable_if_t::value && !std::is_same::value, T> GetElem( - std::vector const& arr, size_t i) { - return arr[i]; -} -// typed array boolean -template -std::enable_if_t::value && std::is_same::value && - std::is_same::value, - bool> -GetElem(std::vector const& arr, size_t i) { - return arr[i] == 1; -} -// json array -template -std::enable_if_t< - std::is_same::value, - std::conditional_t::value, int64_t, - std::conditional_t::value, bool, float>>> -GetElem(std::vector const& arr, size_t i) { - if (std::is_same::value && !IsA(arr[i])) { - return get(arr[i]) == 1; - } - return get(arr[i]); -} template void RegTree::LoadCategoricalSplit(Json const& in) { - using I64ArrayT = std::conditional_t; - using I32ArrayT = std::conditional_t; - - auto const& categories_segments = get(in["categories_segments"]); - auto const& categories_sizes = get(in["categories_sizes"]); - auto const& categories_nodes = get(in["categories_nodes"]); - auto const& categories = get(in["categories"]); - - size_t cnt = 0; + auto const& categories_segments = get>(in["categories_segments"]); + auto const& categories_sizes = get>(in["categories_sizes"]); + auto const& categories_nodes = get>(in["categories_nodes"]); + auto const& categories = get>(in["categories"]); + + auto split_type = get>(in["split_type"]); + bst_node_t n_nodes = split_type.size(); + std::size_t cnt = 0; bst_node_t last_cat_node = -1; if (!categories_nodes.empty()) { last_cat_node = GetElem(categories_nodes, cnt); @@ -938,7 +934,10 @@ void RegTree::LoadCategoricalSplit(Json const& in) { // `categories_segments' is only available for categorical nodes to prevent overhead for // numerical node. As a result, we need to track the categorical nodes we have processed // so far. - for (bst_node_t nidx = 0; nidx < param.num_nodes; ++nidx) { + split_types_.resize(n_nodes, FeatureType::kNumerical); + split_categories_segments_.resize(n_nodes); + for (bst_node_t nidx = 0; nidx < n_nodes; ++nidx) { + split_types_[nidx] = static_cast(GetElem(split_type, nidx)); if (nidx == last_cat_node) { auto j_begin = GetElem(categories_segments, cnt); auto j_end = GetElem(categories_sizes, cnt) + j_begin; @@ -985,15 +984,17 @@ template void RegTree::LoadCategoricalSplit(Json const& in); void RegTree::SaveCategoricalSplit(Json* p_out) const { auto& out = *p_out; - CHECK_EQ(this->split_types_.size(), param.num_nodes); - CHECK_EQ(this->GetSplitCategoriesPtr().size(), param.num_nodes); + CHECK_EQ(this->split_types_.size(), this->Size()); + CHECK_EQ(this->GetSplitCategoriesPtr().size(), this->Size()); I64Array categories_segments; I64Array categories_sizes; I32Array categories; // bst_cat_t = int32_t I32Array categories_nodes; // bst_note_t = int32_t + U8Array split_type(split_types_.size()); for (size_t i = 0; i < nodes_.size(); ++i) { + split_type.Set(i, static_cast>(this->NodeSplitType(i))); if (this->split_types_[i] == FeatureType::kCategorical) { categories_nodes.GetArray().emplace_back(i); auto begin = categories.Size(); @@ -1012,66 +1013,49 @@ void RegTree::SaveCategoricalSplit(Json* p_out) const { } } + out["split_type"] = std::move(split_type); out["categories_segments"] = std::move(categories_segments); out["categories_sizes"] = std::move(categories_sizes); out["categories_nodes"] = std::move(categories_nodes); out["categories"] = std::move(categories); } -template , - typename U8ArrayT = std::conditional_t, - typename I32ArrayT = std::conditional_t, - typename I64ArrayT = std::conditional_t, - typename IndexArrayT = std::conditional_t> -bool LoadModelImpl(Json const& in, TreeParam* param, std::vector* p_stats, - std::vector* p_split_types, std::vector* p_nodes, - std::vector* p_split_categories_segments) { +template +void LoadModelImpl(Json const& in, TreeParam const& param, std::vector* p_stats, + std::vector* p_nodes) { + namespace tf = tree_field; auto& stats = *p_stats; - auto& split_types = *p_split_types; auto& nodes = *p_nodes; - auto& split_categories_segments = *p_split_categories_segments; - FromJson(in["tree_param"], param); - auto n_nodes = param->num_nodes; + auto n_nodes = param.num_nodes; CHECK_NE(n_nodes, 0); // stats - auto const& loss_changes = get(in["loss_changes"]); + auto const& loss_changes = get>(in[tf::kLossChg]); CHECK_EQ(loss_changes.size(), n_nodes); - auto const& sum_hessian = get(in["sum_hessian"]); + auto const& sum_hessian = get>(in[tf::kSumHess]); CHECK_EQ(sum_hessian.size(), n_nodes); - auto const& base_weights = get(in["base_weights"]); + auto const& base_weights = get>(in[tf::kBaseWeight]); CHECK_EQ(base_weights.size(), n_nodes); // nodes - auto const& lefts = get(in["left_children"]); + auto const& lefts = get>(in[tf::kLeft]); CHECK_EQ(lefts.size(), n_nodes); - auto const& rights = get(in["right_children"]); + auto const& rights = get>(in[tf::kRight]); CHECK_EQ(rights.size(), n_nodes); - auto const& parents = get(in["parents"]); + auto const& parents = get>(in[tf::kParent]); CHECK_EQ(parents.size(), n_nodes); - auto const& indices = get(in["split_indices"]); + auto const& indices = get>(in[tf::kSplitIdx]); CHECK_EQ(indices.size(), n_nodes); - auto const& conds = get(in["split_conditions"]); + auto const& conds = get>(in[tf::kSplitCond]); CHECK_EQ(conds.size(), n_nodes); - auto const& default_left = get(in["default_left"]); + auto const& default_left = get>(in[tf::kDftLeft]); CHECK_EQ(default_left.size(), n_nodes); - bool has_cat = get(in).find("split_type") != get(in).cend(); - std::remove_const_t(in["split_type"]))>> - split_type; - if (has_cat) { - split_type = get(in["split_type"]); - } - // Initialization stats = std::remove_reference_t(n_nodes); nodes = std::remove_reference_t(n_nodes); - split_types = std::remove_reference_t(n_nodes); - split_categories_segments = std::remove_reference_t(n_nodes); static_assert(std::is_integral(lefts, 0))>::value); static_assert(std::is_floating_point(loss_changes, 0))>::value); - CHECK_EQ(n_nodes, split_categories_segments.size()); // Set node for (int32_t i = 0; i < n_nodes; ++i) { @@ -1088,41 +1072,46 @@ bool LoadModelImpl(Json const& in, TreeParam* param, std::vector* float cond{GetElem(conds, i)}; bool dft_left{GetElem(default_left, i)}; n = RegTree::Node{left, right, parent, ind, cond, dft_left}; - - if (has_cat) { - split_types[i] = static_cast(GetElem(split_type, i)); - } } - - return has_cat; } void RegTree::LoadModel(Json const& in) { - bool has_cat{false}; - bool typed = IsA(in["loss_changes"]); - bool feature_is_64 = IsA(in["split_indices"]); - if (typed && feature_is_64) { - has_cat = LoadModelImpl(in, ¶m, &stats_, &split_types_, &nodes_, - &split_categories_segments_); - } else if (typed && !feature_is_64) { - has_cat = LoadModelImpl(in, ¶m, &stats_, &split_types_, &nodes_, - &split_categories_segments_); - } else if (!typed && feature_is_64) { - has_cat = LoadModelImpl(in, ¶m, &stats_, &split_types_, &nodes_, - &split_categories_segments_); - } else { - has_cat = LoadModelImpl(in, ¶m, &stats_, &split_types_, &nodes_, - &split_categories_segments_); - } - + namespace tf = tree_field; + + bool typed = IsA(in[tf::kParent]); + auto const& in_obj = get(in); + // basic properties + FromJson(in["tree_param"], ¶m); + // categorical splits + bool has_cat = in_obj.find("split_type") != in_obj.cend(); if (has_cat) { if (typed) { this->LoadCategoricalSplit(in); } else { this->LoadCategoricalSplit(in); } + } + // multi-target + if (param.size_leaf_vector > 1) { + this->p_mt_tree_.reset(new MultiTargetTree{¶m}); + this->GetMultiTargetTree()->LoadModel(in); + return; + } + + bool feature_is_64 = IsA(in["split_indices"]); + if (typed && feature_is_64) { + LoadModelImpl(in, param, &stats_, &nodes_); + } else if (typed && !feature_is_64) { + LoadModelImpl(in, param, &stats_, &nodes_); + } else if (!typed && feature_is_64) { + LoadModelImpl(in, param, &stats_, &nodes_); } else { + LoadModelImpl(in, param, &stats_, &nodes_); + } + + if (!has_cat) { this->split_categories_segments_.resize(this->param.num_nodes); + this->split_types_.resize(this->param.num_nodes); std::fill(split_types_.begin(), split_types_.end(), FeatureType::kNumerical); } @@ -1144,16 +1133,26 @@ void RegTree::LoadModel(Json const& in) { } void RegTree::SaveModel(Json* p_out) const { + auto& out = *p_out; + // basic properties + out["tree_param"] = ToJson(param); + // categorical splits + this->SaveCategoricalSplit(p_out); + // multi-target + if (this->IsMultiTarget()) { + CHECK_GT(param.size_leaf_vector, 1); + this->GetMultiTargetTree()->SaveModel(p_out); + return; + } /* Here we are treating leaf node and internal node equally. Some information like * child node id doesn't make sense for leaf node but we will have to save them to * avoid creating a huge map. One difficulty is XGBoost has deleted node created by * pruner, and this pruner can be used inside another updater so leaf are not necessary * at the end of node array. */ - auto& out = *p_out; CHECK_EQ(param.num_nodes, static_cast(nodes_.size())); CHECK_EQ(param.num_nodes, static_cast(stats_.size())); - out["tree_param"] = ToJson(param); + CHECK_EQ(get(out["tree_param"]["num_nodes"]), std::to_string(param.num_nodes)); auto n_nodes = param.num_nodes; @@ -1167,12 +1166,12 @@ void RegTree::SaveModel(Json* p_out) const { I32Array rights(n_nodes); I32Array parents(n_nodes); - F32Array conds(n_nodes); U8Array default_left(n_nodes); - U8Array split_type(n_nodes); CHECK_EQ(this->split_types_.size(), param.num_nodes); + namespace tf = tree_field; + auto save_tree = [&](auto* p_indices_array) { auto& indices_array = *p_indices_array; for (bst_node_t i = 0; i < n_nodes; ++i) { @@ -1188,33 +1187,28 @@ void RegTree::SaveModel(Json* p_out) const { indices_array.Set(i, n.SplitIndex()); conds.Set(i, n.SplitCond()); default_left.Set(i, static_cast(!!n.DefaultLeft())); - - split_type.Set(i, static_cast(this->NodeSplitType(i))); } }; if (this->param.num_feature > static_cast(std::numeric_limits::max())) { I64Array indices_64(n_nodes); save_tree(&indices_64); - out["split_indices"] = std::move(indices_64); + out[tf::kSplitIdx] = std::move(indices_64); } else { I32Array indices_32(n_nodes); save_tree(&indices_32); - out["split_indices"] = std::move(indices_32); + out[tf::kSplitIdx] = std::move(indices_32); } - this->SaveCategoricalSplit(&out); - - out["split_type"] = std::move(split_type); - out["loss_changes"] = std::move(loss_changes); - out["sum_hessian"] = std::move(sum_hessian); - out["base_weights"] = std::move(base_weights); + out[tf::kLossChg] = std::move(loss_changes); + out[tf::kSumHess] = std::move(sum_hessian); + out[tf::kBaseWeight] = std::move(base_weights); - out["left_children"] = std::move(lefts); - out["right_children"] = std::move(rights); - out["parents"] = std::move(parents); + out[tf::kLeft] = std::move(lefts); + out[tf::kRight] = std::move(rights); + out[tf::kParent] = std::move(parents); - out["split_conditions"] = std::move(conds); - out["default_left"] = std::move(default_left); + out[tf::kSplitCond] = std::move(conds); + out[tf::kDftLeft] = std::move(default_left); } void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat, diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 32b3f4a03d23..607aa8dc4b4b 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -445,7 +445,7 @@ struct GPUHistMakerDevice { dh::caching_device_vector d_split_types; dh::caching_device_vector d_categories; - dh::caching_device_vector d_categories_segments; + dh::caching_device_vector d_categories_segments; if (!categories.empty()) { dh::CopyToD(h_split_types, &d_split_types); @@ -458,12 +458,11 @@ struct GPUHistMakerDevice { p_out_position); } - void FinalisePositionInPage(EllpackPageImpl const *page, - const common::Span d_nodes, - common::Span d_feature_types, - common::Span categories, - common::Span categories_segments, - HostDeviceVector* p_out_position) { + void FinalisePositionInPage( + EllpackPageImpl const* page, const common::Span d_nodes, + common::Span d_feature_types, common::Span categories, + common::Span categories_segments, + HostDeviceVector* p_out_position) { auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id); auto d_gpair = this->gpair; p_out_position->SetDevice(ctx_->gpu_id); diff --git a/tests/cpp/tree/test_multi_target_tree_model.cc b/tests/cpp/tree/test_multi_target_tree_model.cc new file mode 100644 index 000000000000..7d2bd9c7cdd2 --- /dev/null +++ b/tests/cpp/tree/test_multi_target_tree_model.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#include +#include // for Context +#include +#include // for RegTree + +namespace xgboost { +TEST(MultiTargetTree, JsonIO) { + bst_target_t n_targets{3}; + bst_feature_t n_features{4}; + RegTree tree{n_targets, n_features}; + ASSERT_TRUE(tree.IsMultiTarget()); + linalg::Vector base_weight{{1.0f, 2.0f, 3.0f}, {3ul}, Context::kCpuId}; + linalg::Vector left_weight{{2.0f, 3.0f, 4.0f}, {3ul}, Context::kCpuId}; + linalg::Vector right_weight{{3.0f, 4.0f, 5.0f}, {3ul}, Context::kCpuId}; + tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(), + left_weight.HostView(), right_weight.HostView()); + ASSERT_EQ(tree.param.num_nodes, 3); + ASSERT_EQ(tree.param.size_leaf_vector, 3); + ASSERT_EQ(tree.GetMultiTargetTree()->Size(), 3); + ASSERT_EQ(tree.Size(), 3); + + Json jtree{Object{}}; + tree.SaveModel(&jtree); + + auto check_jtree = [](Json jtree, RegTree const& tree) { + ASSERT_EQ(get(jtree["tree_param"]["num_nodes"]), + std::to_string(tree.param.num_nodes)); + ASSERT_EQ(get(jtree["base_weights"]).size(), + tree.param.num_nodes * tree.param.size_leaf_vector); + ASSERT_EQ(get(jtree["parents"]).size(), tree.param.num_nodes); + ASSERT_EQ(get(jtree["left_children"]).size(), tree.param.num_nodes); + ASSERT_EQ(get(jtree["right_children"]).size(), tree.param.num_nodes); + }; + check_jtree(jtree, tree); + + RegTree loaded; + loaded.LoadModel(jtree); + ASSERT_TRUE(loaded.IsMultiTarget()); + ASSERT_EQ(loaded.param.num_nodes, 3); + + Json jtree1{Object{}}; + loaded.SaveModel(&jtree1); + check_jtree(jtree1, tree); +} +} // namespace xgboost diff --git a/tests/cpp/tree/test_tree_model.cc b/tests/cpp/tree/test_tree_model.cc index 65957255bf38..130a0ef7082f 100644 --- a/tests/cpp/tree/test_tree_model.cc +++ b/tests/cpp/tree/test_tree_model.cc @@ -477,7 +477,7 @@ TEST(Tree, JsonIO) { auto tparam = j_tree["tree_param"]; ASSERT_EQ(get(tparam["num_feature"]), "0"); ASSERT_EQ(get(tparam["num_nodes"]), "3"); - ASSERT_EQ(get(tparam["size_leaf_vector"]), "0"); + ASSERT_EQ(get(tparam["size_leaf_vector"]), "1"); ASSERT_EQ(get(j_tree["left_children"]).size(), 3ul); ASSERT_EQ(get(j_tree["right_children"]).size(), 3ul);