Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Define core multi-target regression tree structure. #8884

Merged
merged 5 commits into from
Mar 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions R-package/src/Makevars.in
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ OBJECTS= \
$(PKGROOT)/src/tree/fit_stump.o \
$(PKGROOT)/src/tree/tree_model.o \
$(PKGROOT)/src/tree/tree_updater.o \
$(PKGROOT)/src/tree/multi_target_tree_model.o \
$(PKGROOT)/src/tree/updater_approx.o \
$(PKGROOT)/src/tree/updater_colmaker.o \
$(PKGROOT)/src/tree/updater_prune.o \
Expand Down
1 change: 1 addition & 0 deletions R-package/src/Makevars.win
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ OBJECTS= \
$(PKGROOT)/src/tree/param.o \
$(PKGROOT)/src/tree/fit_stump.o \
$(PKGROOT)/src/tree/tree_model.o \
$(PKGROOT)/src/tree/multi_target_tree_model.o \
$(PKGROOT)/src/tree/tree_updater.o \
$(PKGROOT)/src/tree/updater_approx.o \
$(PKGROOT)/src/tree/updater_colmaker.o \
Expand Down
8 changes: 4 additions & 4 deletions include/xgboost/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,11 @@ using bst_bin_t = int32_t; // NOLINT
*/
using bst_row_t = std::size_t; // NOLINT
/*! \brief Type for tree node index. */
using bst_node_t = int32_t; // NOLINT
using bst_node_t = std::int32_t; // NOLINT
/*! \brief Type for ranking group index. */
using bst_group_t = uint32_t; // NOLINT
/*! \brief Type for indexing target variables. */
using bst_target_t = std::size_t; // NOLINT
using bst_group_t = std::uint32_t; // NOLINT
/*! \brief Type for indexing into output targets. */
using bst_target_t = std::uint32_t; // NOLINT

namespace detail {
/*! \brief Implementation of gradient statistics pair. Template specialisation
Expand Down
96 changes: 96 additions & 0 deletions include/xgboost/multi_target_tree_model.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/**
* Copyright 2023 by XGBoost contributors
*
* \brief Core data structure for multi-target trees.
*/
#ifndef XGBOOST_MULTI_TARGET_TREE_MODEL_H_
#define XGBOOST_MULTI_TARGET_TREE_MODEL_H_
#include <xgboost/base.h> // for bst_node_t, bst_target_t, bst_feature_t
#include <xgboost/context.h> // for Context
#include <xgboost/linalg.h> // for VectorView
#include <xgboost/model.h> // for Model
#include <xgboost/span.h> // for Span

#include <cinttypes> // for uint8_t
#include <cstddef> // for size_t
#include <vector> // for vector

namespace xgboost {
struct TreeParam;
/**
 * \brief Tree structure for multi-target model.
 *
 * Nodes are stored as parallel arrays indexed by node id. The weights of all nodes live in
 * one flat buffer, with `NumTarget()` consecutive values per node.
 */
class MultiTargetTree : public Model {
 public:
  /** \brief Sentinel node id used to mark a missing parent or child. */
  static bst_node_t constexpr InvalidNodeId() { return -1; }

 private:
  /** \brief Non-owning pointer to the tree parameters; set by the constructor. */
  TreeParam const* param_;
  /** \brief Left child of each node; `InvalidNodeId()` marks a leaf (see IsLeaf). */
  std::vector<bst_node_t> left_;
  /** \brief Right child of each node. */
  std::vector<bst_node_t> right_;
  /** \brief Parent of each node; Depth() stops at a node whose parent is `InvalidNodeId()`. */
  std::vector<bst_node_t> parent_;
  /** \brief Feature index used by the split at each node. */
  std::vector<bst_feature_t> split_index_;
  /** \brief Non-zero when the default direction of a node is the left child. */
  std::vector<std::uint8_t> default_left_;
  /** \brief Split condition of each node. */
  std::vector<float> split_conds_;
  /** \brief Flat weight buffer, `NumTarget()` values per node. */
  std::vector<float> weights_;

  /**
   * \brief Read-only view over the `NumTarget()` weights that belong to a node.
   *
   * \param nidx Node index.
   */
  [[nodiscard]] linalg::VectorView<float const> NodeWeight(bst_node_t nidx) const {
    std::size_t const n_targets = this->NumTarget();
    // Widen before multiplying: int32 * uint32 would be evaluated in 32-bit unsigned
    // arithmetic and could wrap around for large trees.
    std::size_t const beg = static_cast<std::size_t>(nidx) * n_targets;
    auto v = common::Span<float const>{weights_}.subspan(beg, n_targets);
    return linalg::MakeTensorView(Context::kCpuId, v, v.size());
  }
  /**
   * \brief Mutable view over the `NumTarget()` weights that belong to a node.
   *
   * \param nidx Node index.
   */
  [[nodiscard]] linalg::VectorView<float> NodeWeight(bst_node_t nidx) {
    std::size_t const n_targets = this->NumTarget();
    // Same widening as the const overload to avoid 32-bit wrap-around of the offset.
    std::size_t const beg = static_cast<std::size_t>(nidx) * n_targets;
    auto v = common::Span<float>{weights_}.subspan(beg, n_targets);
    return linalg::MakeTensorView(Context::kCpuId, v, v.size());
  }

 public:
  explicit MultiTargetTree(TreeParam const* param);
  /**
   * \brief Set the weight for a leaf.
   */
  void SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight);
  /**
   * \brief Expand a leaf into split node.
   */
  void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left,
              linalg::VectorView<float const> base_weight,
              linalg::VectorView<float const> left_weight,
              linalg::VectorView<float const> right_weight);

  /** \brief Whether the node has no left child. The index is not bounds-checked. */
  [[nodiscard]] bool IsLeaf(bst_node_t nidx) const { return left_[nidx] == InvalidNodeId(); }
  /** \brief Parent of a node (bounds-checked access). */
  [[nodiscard]] bst_node_t Parent(bst_node_t nidx) const { return parent_.at(nidx); }
  /** \brief Left child of a node (bounds-checked access). */
  [[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const { return left_.at(nidx); }
  /** \brief Right child of a node (bounds-checked access). */
  [[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const { return right_.at(nidx); }

  /** \brief Feature index of the split at a node. */
  [[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const { return split_index_[nidx]; }
  /** \brief Split condition at a node. */
  [[nodiscard]] float SplitCond(bst_node_t nidx) const { return split_conds_[nidx]; }
  /** \brief Whether the default direction at a node is the left child. */
  [[nodiscard]] bool DefaultLeft(bst_node_t nidx) const { return default_left_[nidx]; }
  /** \brief Child selected by the default direction of a node. */
  [[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
    return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
  }

  /** \brief Number of output targets, i.e. the number of weights stored per node. */
  [[nodiscard]] bst_target_t NumTarget() const;

  [[nodiscard]] std::size_t Size() const;

  /**
   * \brief Number of parent links between a node and the node without a parent.
   *
   * \param nidx Node index.
   */
  [[nodiscard]] bst_node_t Depth(bst_node_t nidx) const {
    bst_node_t depth{0};
    // Look up each ancestor exactly once while walking towards the root.
    for (auto parent = Parent(nidx); parent != InvalidNodeId(); parent = Parent(parent)) {
      ++depth;
    }
    return depth;
  }

  /**
   * \brief Weights of a leaf node; the node is CHECK-ed to be a leaf first.
   *
   * \param nidx Node index.
   */
  [[nodiscard]] linalg::VectorView<float const> LeafValue(bst_node_t nidx) const {
    CHECK(IsLeaf(nidx));
    return this->NodeWeight(nidx);
  }

  void LoadModel(Json const& in) override;
  void SaveModel(Json* out) const override;
};
} // namespace xgboost
#endif // XGBOOST_MULTI_TARGET_TREE_MODEL_H_
Loading