diff --git a/CMakeLists.txt b/CMakeLists.txt
index 78d5555ee169..005ff3cc98b3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -38,6 +38,7 @@ include_directories (
 file(GLOB_RECURSE SOURCES
     src/*.cc
     src/*.h
+    include/*.h
 )
 # Only add main function for executable target
 list(REMOVE_ITEM SOURCES ${PROJECT_SOURCE_DIR}/src/cli_main.cc)
diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc
index efa43fd00c01..88f780cbf525 100644
--- a/amalgamation/xgboost-all0.cc
+++ b/amalgamation/xgboost-all0.cc
@@ -32,6 +32,10 @@
 #include "../src/data/simple_dmatrix.cc"
 #include "../src/data/sparse_page_raw_format.cc"
 
+// prediction
+#include "../src/predictor/predictor.cc"
+#include "../src/predictor/cpu_predictor.cc"
+
 #if DMLC_ENABLE_STD_THREAD
 #include "../src/data/sparse_page_source.cc"
 #include "../src/data/sparse_page_dmatrix.cc"
diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h
index 22f99a7a53fb..d5e86677ce8e 100644
--- a/include/xgboost/gbm.h
+++ b/include/xgboost/gbm.h
@@ -77,7 +77,7 @@ class GradientBooster {
    * \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
    *    we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
    */
-  virtual void Predict(DMatrix* dmat,
+  virtual void PredictBatch(DMatrix* dmat,
                        std::vector<bst_float>* out_preds,
                        unsigned ntree_limit = 0) = 0;
   /*!
@@ -92,7 +92,7 @@ class GradientBooster {
    * \param root_index the root index
    * \sa Predict
    */
-  virtual void Predict(const SparseBatch::Inst& inst,
+  virtual void PredictInstance(const SparseBatch::Inst& inst,
                        std::vector<bst_float>* out_preds,
                        unsigned ntree_limit = 0,
                        unsigned root_index = 0) = 0;
diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h
index ab41d337b026..bcd200bf39b1 100644
--- a/include/xgboost/learner.h
+++ b/include/xgboost/learner.h
@@ -189,7 +189,7 @@ inline void Learner::Predict(const SparseBatch::Inst& inst,
                              bool output_margin,
                              std::vector<bst_float>* out_preds,
                              unsigned ntree_limit) const {
-  gbm_->Predict(inst, out_preds, ntree_limit);
+  gbm_->PredictInstance(inst, out_preds, ntree_limit);
   if (!output_margin) {
     obj_->PredTransform(out_preds);
   }
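The two overloaded GradientBooster::Predict entry points are renamed rather than left overloaded, so every call site now states whether it predicts a whole DMatrix or a single sparse instance. A minimal caller-side sketch (the booster, dmat and inst variables are assumptions for illustration; only the method names and signatures come from this patch):

    // Batch path: predict every row of a DMatrix.
    std::vector<xgboost::bst_float> batch_preds;
    booster->PredictBatch(dmat, &batch_preds, /*ntree_limit=*/0);
    // Instance path: predict one sparse row (not thread safe).
    std::vector<xgboost::bst_float> inst_preds;
    booster->PredictInstance(inst, &inst_preds, /*ntree_limit=*/0,
                             /*root_index=*/0);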
diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h
new file mode 100644
index 000000000000..bc37f66d74e9
--- /dev/null
+++ b/include/xgboost/predictor.h
@@ -0,0 +1,172 @@
+/*!
+ * Copyright by Contributors
+ * \file predictor.h
+ * \brief Interface of predictor,
+ *  performs predictions for a gradient booster.
+ */
+#pragma once
+#include <dmlc/registry.h>
+#include <xgboost/base.h>
+#include <xgboost/data.h>
+#include <functional>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "../../src/gbm/gbtree_model.h"
+
+// Forward declarations
+namespace xgboost {
+class DMatrix;
+class TreeUpdater;
+}
+namespace xgboost {
+namespace gbm {
+struct GBTreeModel;
+}
+}  // namespace xgboost
+
+namespace xgboost {
+
+/**
+ * \class Predictor
+ *
+ * \brief Performs prediction on individual training instances or batches of
+ *  instances for GBTree. The predictor also manages a prediction cache
+ *  associated with input matrices. If possible, it will use previously
+ *  calculated predictions instead of calculating new predictions. Prediction
+ *  functions all take a GBTreeModel and a DMatrix as input and output a
+ *  vector of predictions. The predictor does not modify any state of the
+ *  model itself.
+ */
+
+class Predictor {
+ public:
+  virtual ~Predictor() {}
+
+  /**
+   * \fn void Predictor::InitCache(const std::vector<std::shared_ptr<DMatrix> >& cache);
+   *
+   * \brief Register input matrices in prediction cache.
+   *
+   * \param cache Vector of DMatrix's to be used in prediction.
+   */
+
+  void InitCache(const std::vector<std::shared_ptr<DMatrix> >& cache);
+
+  /**
+   * \fn virtual void Predictor::PredictBatch(DMatrix* dmat, std::vector<bst_float>* out_preds, const gbm::GBTreeModel& model, int tree_begin, unsigned ntree_limit = 0) = 0;
+   *
+   * \brief Generate batch predictions for a given feature matrix. May use
+   *  cached predictions if available instead of calculating from scratch.
+   *
+   * \param [in,out] dmat      Feature matrix.
+   * \param [in,out] out_preds The output preds.
+   * \param model              The model to predict from.
+   * \param tree_begin         The tree begin index.
+   * \param ntree_limit        (Optional) The ntree limit. 0 means do not limit trees.
+   */
+
+  virtual void PredictBatch(DMatrix* dmat, std::vector<bst_float>* out_preds,
+                            const gbm::GBTreeModel& model, int tree_begin,
+                            unsigned ntree_limit = 0) = 0;
+
+  /**
+   * \fn virtual void Predictor::UpdatePredictionCache(const gbm::GBTreeModel& model, std::vector<std::unique_ptr<TreeUpdater> >* updaters, int num_new_trees) = 0;
+   *
+   * \brief Update the internal prediction cache using newly added trees. Will
+   *  use the tree updater to do this if possible. Should be called as a part
+   *  of the tree boosting process to facilitate the look up of predictions at
+   *  a later time.
+   *
+   * \param model             The model.
+   * \param [in,out] updaters The updater sequence for gradient boosting.
+   * \param num_new_trees     Number of new trees.
+   */
+
+  virtual void UpdatePredictionCache(
+      const gbm::GBTreeModel& model,
+      std::vector<std::unique_ptr<TreeUpdater> >* updaters,
+      int num_new_trees) = 0;
+
+  /**
+   * \fn virtual void Predictor::PredictInstance(const SparseBatch::Inst& inst, std::vector<bst_float>* out_preds, const gbm::GBTreeModel& model, unsigned ntree_limit = 0, unsigned root_index = 0) = 0;
+   *
+   * \brief Online prediction function; predicts the score for one instance at
+   *  a time. NOTE: use the batch prediction interface if possible; batch
+   *  prediction is usually more efficient than online prediction. This
+   *  function is NOT threadsafe, make sure you only call it from one thread.
+   *
+   * \param inst               The instance to predict.
+   * \param [in,out] out_preds The output preds.
+   * \param model              The model to predict from.
+   * \param ntree_limit        (Optional) The ntree limit.
+   * \param root_index         (Optional) Zero-based index of the root.
+   */
+
+  virtual void PredictInstance(const SparseBatch::Inst& inst,
+                               std::vector<bst_float>* out_preds,
+                               const gbm::GBTreeModel& model,
+                               unsigned ntree_limit = 0,
+                               unsigned root_index = 0) = 0;
+
+  /**
+   * \fn virtual void Predictor::PredictLeaf(DMatrix* dmat, std::vector<bst_float>* out_preds, const gbm::GBTreeModel& model, unsigned ntree_limit = 0) = 0;
+   *
+   * \brief Predict the leaf index of each tree; the output will be an
+   *  nsample * ntree vector. This is only valid in the gbtree predictor.
+   *
+   * \param [in,out] dmat      The input feature matrix.
+   * \param [in,out] out_preds The output preds.
+   * \param model              Model to make predictions from.
+   * \param ntree_limit        (Optional) The ntree limit.
+   */
+
+  virtual void PredictLeaf(DMatrix* dmat, std::vector<bst_float>* out_preds,
+                           const gbm::GBTreeModel& model,
+                           unsigned ntree_limit = 0) = 0;
+
+  /**
+   * \fn virtual void Predictor::PredictContribution(DMatrix* dmat, std::vector<bst_float>* out_contribs, const gbm::GBTreeModel& model, unsigned ntree_limit = 0) = 0;
+   *
+   * \brief Feature contributions to individual predictions; the output will
+   *  be a vector of length (nfeats + 1) * num_output_group * nsample,
+   *  arranged in that order.
+   *
+   * \param [in,out] dmat         The input feature matrix.
+   * \param [in,out] out_contribs The output feature contribs.
+   * \param model                 Model to make predictions from.
+   * \param ntree_limit           (Optional) The ntree limit.
+   */
+
+  virtual void PredictContribution(DMatrix* dmat,
+                                   std::vector<bst_float>* out_contribs,
+                                   const gbm::GBTreeModel& model,
+                                   unsigned ntree_limit = 0) = 0;
+
+  /**
+   * \fn static Predictor* Predictor::Create(std::string name);
+   *
+   * \brief Creates a new Predictor*.
+   */
+
+  static Predictor* Create(std::string name);
+
+ protected:
+  /**
+   * \struct PredictionCacheEntry
+   *
+   * \brief Contains pointer to input matrix and associated cached predictions.
+   */
+
+  struct PredictionCacheEntry {
+    std::shared_ptr<DMatrix> data;
+    std::vector<bst_float> predictions;
+  };
+
+  /**
+   * \brief Map of matrices and associated cached predictions to facilitate
+   *  storing and looking up predictions.
+   */
+
+  std::unordered_map<DMatrix*, PredictionCacheEntry> cache_;
+};
+
+/*!
+ * \brief Registry entry for predictor.
+ */
+struct PredictorReg
+    : public dmlc::FunctionRegEntryBase<PredictorReg,
+                                        std::function<Predictor*()> > {};
+
+#define XGBOOST_REGISTER_PREDICTOR(UniqueId, Name)      \
+  static DMLC_ATTRIBUTE_UNUSED ::xgboost::PredictorReg& \
+      __make_##PredictorReg##_##UniqueId##__ =          \
+          ::dmlc::Registry<::xgboost::PredictorReg>::Get()->__REGISTER__(Name)
+}  // namespace xgboost
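Taken together, the doc comments above describe a create, register-caches, predict lifecycle. A sketch of that sequence under stated assumptions (model is a configured gbm::GBTreeModel and dmat a std::shared_ptr<DMatrix>; neither is defined in this header):

    std::unique_ptr<xgboost::Predictor> predictor(
        xgboost::Predictor::Create("cpu_predictor"));
    predictor->InitCache({dmat});  // later PredictBatch calls may hit this cache
    std::vector<xgboost::bst_float> out_preds;
    predictor->PredictBatch(dmat.get(), &out_preds, model, /*tree_begin=*/0);

Because the cache is keyed on the raw DMatrix pointer, only matrices registered through InitCache can ever be served from cache; everything else falls through to a full prediction loop.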
diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc
index 9aa661f6b272..5c1f2474a48b 100644
--- a/src/gbm/gblinear.cc
+++ b/src/gbm/gblinear.cc
@@ -170,7 +170,7 @@ class GBLinear : public GradientBooster {
       }
     }
   }
-  void Predict(DMatrix *p_fmat,
+  void PredictBatch(DMatrix *p_fmat,
                std::vector<bst_float> *out_preds,
                unsigned ntree_limit) override {
     if (model.weight.size() == 0) {
@@ -205,7 +205,7 @@ class GBLinear : public GradientBooster {
     }
   }
   // add base margin
-  void Predict(const SparseBatch::Inst &inst,
+  void PredictInstance(const SparseBatch::Inst &inst,
                std::vector<bst_float> *out_preds,
                unsigned ntree_limit,
                unsigned root_index) override {
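For the linear booster the renamed methods are the same computation over different inputs: a prediction is the base margin plus a weight/feature dot product plus a trailing bias weight. A rough single-instance, single-group sketch (a sketch only; weight, bias and the single-group layout are assumptions, not the gblinear.cc internals):

    xgboost::bst_float margin = base_margin;
    for (xgboost::bst_uint i = 0; i < inst.length; ++i) {
      margin += weight[inst[i].index] * inst[i].fvalue;  // nonzero features only
    }
    margin += bias;  // bias weight stored after the feature weights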
diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc
index ed15933336be..c5ca3aeb45ec 100644
--- a/src/gbm/gbtree.cc
+++ b/src/gbm/gbtree.cc
@@ -9,18 +9,17 @@
 #include <dmlc/omp.h>
 #include <dmlc/parameter.h>
 #include <xgboost/gbm.h>
+#include <xgboost/predictor.h>
 #include <xgboost/tree_updater.h>
-
 #include <vector>
 #include <memory>
 #include <utility>
 #include <string>
 #include <limits>
-#include <cstring>
 #include <algorithm>
 #include "../common/common.h"
-
 #include "../common/random.h"
+#include "gbtree_model.h"
 
 namespace xgboost {
 namespace gbm {
@@ -121,47 +120,6 @@ struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
   }
 };
 
-/*! \brief model parameters */
-struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
-  /*! \brief number of trees */
-  int num_trees;
-  /*! \brief number of roots */
-  int num_roots;
-  /*! \brief number of features to be used by trees */
-  int num_feature;
-  /*! \brief pad this space, for backward compatibility reason.*/
-  int pad_32bit;
-  /*! \brief deprecated padding space. */
-  int64_t num_pbuffer_deprecated;
-  /*!
-   * \brief how many output group a single instance can produce
-   *  this affects the behavior of number of output we have:
-   *    suppose we have n instance and k group, output will be k * n
-   */
-  int num_output_group;
-  /*! \brief size of leaf vector needed in tree */
-  int size_leaf_vector;
-  /*! \brief reserved parameters */
-  int reserved[32];
-  /*! \brief constructor */
-  GBTreeModelParam() {
-    std::memset(this, 0, sizeof(GBTreeModelParam));
-    static_assert(sizeof(GBTreeModelParam) == (4 + 2 + 2 + 32) * sizeof(int),
-                  "64/32 bit compatibility issue");
-  }
-  // declare parameters, only declare those that need to be set.
-  DMLC_DECLARE_PARAMETER(GBTreeModelParam) {
-    DMLC_DECLARE_FIELD(num_output_group).set_lower_bound(1).set_default(1)
-        .describe("Number of output groups to be predicted,"\
-                  " used for multi-class classification.");
-    DMLC_DECLARE_FIELD(num_roots).set_lower_bound(1).set_default(1)
-        .describe("Tree updater sequence.");
-    DMLC_DECLARE_FIELD(num_feature).set_lower_bound(0)
-        .describe("Number of features used for training and prediction.");
-    DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0)
-        .describe("Reserved option for vector tree.");
-  }
-};
 
 // cache entry
 struct CacheEntry {
@@ -172,22 +130,18 @@ struct CacheEntry {
 // gradient boosted trees
 class GBTree : public GradientBooster {
  public:
-  explicit GBTree(bst_float base_margin) : base_margin_(base_margin) {}
+  explicit GBTree(bst_float base_margin)
+      : model_(base_margin),
+        predictor(
+            std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor"))) {}
 
   void InitCache(const std::vector<std::shared_ptr<DMatrix> >& cache) {
-    for (const std::shared_ptr<DMatrix>& d : cache) {
-      CacheEntry e;
-      e.data = d;
-      cache_[d.get()] = std::move(e);
-    }
+    predictor->InitCache(cache);
   }
 
   void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) override {
     this->cfg = cfg;
-    // initialize model parameters if not yet been initialized.
-    if (trees.size() == 0) {
-      mparam.InitAllowUnknown(cfg);
-    }
+    model_.Configure(cfg);
     // initialize the updaters only when needed.
     std::string updater_seq = tparam.updater_seq;
     tparam.InitAllowUnknown(cfg);
@@ -196,48 +150,25 @@ class GBTree : public GradientBooster {
       up->Init(cfg);
     }
     // for the 'update' process_type, move trees into trees_to_update
-    if (tparam.process_type == kUpdate && trees_to_update.size() == 0u) {
-      for (size_t i = 0; i < trees.size(); ++i) {
-        trees_to_update.push_back(std::move(trees[i]));
-      }
-      trees.clear();
-      mparam.num_trees = 0;
+    if (tparam.process_type == kUpdate) {
+      model_.InitTreesToUpdate();
     }
   }
 
   void Load(dmlc::Stream* fi) override {
-    CHECK_EQ(fi->Read(&mparam, sizeof(mparam)), sizeof(mparam))
-        << "GBTree: invalid model file";
-    trees.clear();
-    trees_to_update.clear();
-    for (int i = 0; i < mparam.num_trees; ++i) {
-      std::unique_ptr<RegTree> ptr(new RegTree());
-      ptr->Load(fi);
-      trees.push_back(std::move(ptr));
-    }
-    tree_info.resize(mparam.num_trees);
-    if (mparam.num_trees != 0) {
-      CHECK_EQ(fi->Read(dmlc::BeginPtr(tree_info), sizeof(int) * mparam.num_trees),
-               sizeof(int) * mparam.num_trees);
-    }
+    model_.Load(fi);
+
     this->cfg.clear();
     this->cfg.push_back(std::make_pair(std::string("num_feature"),
-                                       common::ToString(mparam.num_feature)));
+                                       common::ToString(model_.param.num_feature)));
   }
 
   void Save(dmlc::Stream* fo) const override {
-    CHECK_EQ(mparam.num_trees, static_cast<int>(trees.size()));
-    fo->Write(&mparam, sizeof(mparam));
-    for (size_t i = 0; i < trees.size(); ++i) {
-      trees[i]->Save(fo);
-    }
-    if (tree_info.size() != 0) {
-      fo->Write(dmlc::BeginPtr(tree_info), sizeof(int) * tree_info.size());
-    }
+    model_.Save(fo);
   }
 
   bool AllowLazyCheckPoint() const override {
-    return mparam.num_output_group == 1 ||
+    return model_.param.num_output_group == 1 ||
         tparam.updater_seq.find("distcol") != std::string::npos;
   }
 
@@ -246,7 +177,7 @@ class GBTree : public GradientBooster {
                ObjFunction* obj) override {
     const std::vector<bst_gpair>& gpair = *in_gpair;
     std::vector<std::vector<std::unique_ptr<RegTree> > > new_trees;
-    const int ngroup = mparam.num_output_group;
+    const int ngroup = model_.param.num_output_group;
     if (ngroup == 1) {
       std::vector<std::unique_ptr<RegTree> > ret;
       BoostNewTrees(gpair, p_fmat, 0, &ret);
@@ -275,167 +206,39 @@ class GBTree : public GradientBooster {
     }
   }
 
-  void Predict(DMatrix* p_fmat,
+  void PredictBatch(DMatrix* p_fmat,
                std::vector<bst_float>* out_preds,
                unsigned ntree_limit) override {
-    if (ntree_limit == 0 ||
-        ntree_limit * mparam.num_output_group >= trees.size()) {
-      auto it = cache_.find(p_fmat);
-      if (it != cache_.end()) {
-        std::vector<bst_float>& y = it->second.predictions;
-        if (y.size() != 0) {
-          out_preds->resize(y.size());
-          std::copy(y.begin(), y.end(), out_preds->begin());
-          return;
-        }
-      }
-    }
-    PredLoopInternal<GBTree>(p_fmat, out_preds, 0, ntree_limit, true);
+    predictor->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit);
   }
 
-  void Predict(const SparseBatch::Inst& inst,
+  void PredictInstance(const SparseBatch::Inst& inst,
               std::vector<bst_float>* out_preds,
               unsigned ntree_limit,
               unsigned root_index) override {
-    if (thread_temp.size() == 0) {
-      thread_temp.resize(1, RegTree::FVec());
-      thread_temp[0].Init(mparam.num_feature);
-    }
-    ntree_limit *= mparam.num_output_group;
-    if (ntree_limit == 0 || ntree_limit > trees.size()) {
-      ntree_limit = static_cast<unsigned>(trees.size());
-    }
-    out_preds->resize(mparam.num_output_group * (mparam.size_leaf_vector+1));
-    // loop over output groups
-    for (int gid = 0; gid < mparam.num_output_group; ++gid) {
-      (*out_preds)[gid] =
-          PredValue(inst, gid, root_index,
-                    &thread_temp[0], 0, ntree_limit) + base_margin_;
-    }
+    predictor->PredictInstance(inst, out_preds, model_,
+                               ntree_limit, root_index);
   }
 
   void PredictLeaf(DMatrix* p_fmat,
                    std::vector<bst_float>* out_preds,
                    unsigned ntree_limit) override {
-    const int nthread = omp_get_max_threads();
-    InitThreadTemp(nthread);
-    this->PredPath(p_fmat, out_preds, ntree_limit);
+    predictor->PredictLeaf(p_fmat, out_preds, model_, ntree_limit);
   }
 
   void PredictContribution(DMatrix* p_fmat,
                            std::vector<bst_float>* out_contribs,
                            unsigned ntree_limit) override {
-    const int nthread = omp_get_max_threads();
-    InitThreadTemp(nthread);
-    this->PredContrib(p_fmat, out_contribs, ntree_limit);
+    predictor->PredictContribution(p_fmat, out_contribs, model_, ntree_limit);
   }
 
   std::vector<std::string> DumpModel(const FeatureMap& fmap,
                                      bool with_stats,
                                      std::string format) const override {
-    std::vector<std::string> dump;
-    for (size_t i = 0; i < trees.size(); i++) {
-      dump.push_back(trees[i]->DumpModel(fmap, with_stats, format));
-    }
-    return dump;
+    return model_.DumpModel(fmap, with_stats, format);
   }
 
  protected:
-  // internal prediction loop
-  // add predictions to out_preds
-  template<typename Derived>
-  inline void PredLoopInternal(
-      DMatrix* p_fmat,
-      std::vector<bst_float>* out_preds,
-      unsigned tree_begin,
-      unsigned ntree_limit,
-      bool init_out_preds) {
-    int num_group = mparam.num_output_group;
-    ntree_limit *= num_group;
-    if (ntree_limit == 0 || ntree_limit > trees.size()) {
-      ntree_limit = static_cast<unsigned>(trees.size());
-    }
-
-    if (init_out_preds) {
-      size_t n = num_group * p_fmat->info().num_row;
-      const std::vector<bst_float>& base_margin = p_fmat->info().base_margin;
-      out_preds->resize(n);
-      if (base_margin.size() != 0) {
-        CHECK_EQ(out_preds->size(), n);
-        std::copy(base_margin.begin(), base_margin.end(), out_preds->begin());
-      } else {
-        std::fill(out_preds->begin(), out_preds->end(), base_margin_);
-      }
-    }
-
-    if (num_group == 1) {
-      PredLoopSpecalize<Derived>(p_fmat, out_preds, 1,
-                                 tree_begin, ntree_limit);
-    } else {
-      PredLoopSpecalize<Derived>(p_fmat, out_preds, num_group,
-                                 tree_begin, ntree_limit);
-    }
-  }
-  template<typename Derived>
-  inline void PredLoopSpecalize(
-      DMatrix* p_fmat,
-      std::vector<bst_float>* out_preds,
-      int num_group,
-      unsigned tree_begin,
-      unsigned tree_end) {
-    const MetaInfo& info = p_fmat->info();
-    const int nthread = omp_get_max_threads();
-    CHECK_EQ(num_group, mparam.num_output_group);
-    InitThreadTemp(nthread);
-    std::vector<bst_float>& preds = *out_preds;
-    CHECK_EQ(mparam.size_leaf_vector, 0)
-        << "size_leaf_vector is enforced to 0 so far";
-    CHECK_EQ(preds.size(), p_fmat->info().num_row * num_group);
-    // start collecting the prediction
-    dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
-    Derived* self = static_cast<Derived*>(this);
-    iter->BeforeFirst();
-    while (iter->Next()) {
-      const RowBatch &batch = iter->Value();
-      // parallel over local batch
-      const int K = 8;
-      const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
-      const bst_omp_uint rest = nsize % K;
-      #pragma omp parallel for schedule(static)
-      for (bst_omp_uint i = 0; i < nsize - rest; i += K) {
-        const int tid = omp_get_thread_num();
-        RegTree::FVec& feats = thread_temp[tid];
-        int64_t ridx[K];
-        RowBatch::Inst inst[K];
-        for (int k = 0; k < K; ++k) {
-          ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
-        }
-        for (int k = 0; k < K; ++k) {
-          inst[k] = batch[i + k];
-        }
-        for (int k = 0; k < K; ++k) {
-          for (int gid = 0; gid < num_group; ++gid) {
-            const size_t offset = ridx[k] * num_group + gid;
-            preds[offset] +=
-                self->PredValue(inst[k], gid, info.GetRoot(ridx[k]),
-                                &feats, tree_begin, tree_end);
-          }
-        }
-      }
-      for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
-        RegTree::FVec& feats = thread_temp[0];
-        const int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
-        const RowBatch::Inst inst = batch[i];
-        for (int gid = 0; gid < num_group; ++gid) {
-          const size_t offset = ridx * num_group + gid;
-          preds[offset] +=
-              self->PredValue(inst, gid, info.GetRoot(ridx),
-                              &feats, tree_begin, tree_end);
-        }
-      }
-    }
-  }
   // initialize updater before using them
   inline void InitUpdater() {
     if (updaters.size() != 0) return;
@@ -466,9 +269,9 @@ class GBTree : public GradientBooster {
         new_trees.push_back(ptr.get());
         ret->push_back(std::move(ptr));
       } else if (tparam.process_type == kUpdate) {
-        CHECK_LT(trees.size(), trees_to_update.size());
+        CHECK_LT(model_.trees.size(), model_.trees_to_update.size());
         // move an existing tree from trees_to_update
-        auto t = std::move(trees_to_update[trees.size() +
+        auto t = std::move(model_.trees_to_update[model_.trees.size() +
                            bst_group * tparam.num_parallel_tree + i]);
         new_trees.push_back(t.get());
         ret->push_back(std::move(t));
@@ -483,173 +286,22 @@ class GBTree : public GradientBooster {
   virtual void CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
                            int bst_group) {
-    size_t old_ntree = trees.size();
-    for (size_t i = 0; i < new_trees.size(); ++i) {
-      trees.push_back(std::move(new_trees[i]));
-      tree_info.push_back(bst_group);
-    }
-    mparam.num_trees += static_cast<int>(new_trees.size());
+    model_.CommitModel(std::move(new_trees), bst_group);
 
-    // update cache entry
-    for (auto &kv : cache_) {
-      CacheEntry& e = kv.second;
-
-      if (e.predictions.size() == 0) {
-        PredLoopInternal<GBTree>(
-            e.data.get(), &(e.predictions),
-            0, trees.size(), true);
-      } else {
-        if (mparam.num_output_group == 1 && updaters.size() > 0 && new_trees.size() == 1
-            && updaters.back()->UpdatePredictionCache(e.data.get(), &(e.predictions))) {
-          {}  // do nothing
-        } else {
-          PredLoopInternal<GBTree>(
-              e.data.get(), &(e.predictions),
-              old_ntree, trees.size(), false);
-        }
-      }
-    }
+    predictor->UpdatePredictionCache(model_, &updaters, new_trees.size());
   }
-  // make a prediction for a single instance
-  inline bst_float PredValue(const RowBatch::Inst &inst,
-                             int bst_group,
-                             unsigned root_index,
-                             RegTree::FVec *p_feats,
-                             unsigned tree_begin,
-                             unsigned tree_end) {
-    bst_float psum = 0.0f;
-    p_feats->Fill(inst);
-    for (size_t i = tree_begin; i < tree_end; ++i) {
-      if (tree_info[i] == bst_group) {
-        int tid = trees[i]->GetLeafIndex(*p_feats, root_index);
-        psum += (*trees[i])[tid].leaf_value();
-      }
-    }
-    p_feats->Drop(inst);
-    return psum;
-  }
-  // predict independent leaf index
-  inline void PredPath(DMatrix *p_fmat,
-                       std::vector<bst_float> *out_preds,
-                       unsigned ntree_limit) {
-    const MetaInfo& info = p_fmat->info();
-    // number of valid trees
-    ntree_limit *= mparam.num_output_group;
-    if (ntree_limit == 0 || ntree_limit > trees.size()) {
-      ntree_limit = static_cast<unsigned>(trees.size());
-    }
-    std::vector<bst_float>& preds = *out_preds;
-    preds.resize(info.num_row * ntree_limit);
-    // start collecting the prediction
-    dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
-    iter->BeforeFirst();
-    while (iter->Next()) {
-      const RowBatch& batch = iter->Value();
-      // parallel over local batch
-      const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
-      #pragma omp parallel for schedule(static)
-      for (bst_omp_uint i = 0; i < nsize; ++i) {
-        const int tid = omp_get_thread_num();
-        size_t ridx = static_cast<size_t>(batch.base_rowid + i);
-        RegTree::FVec &feats = thread_temp[tid];
-        feats.Fill(batch[i]);
-        for (unsigned j = 0; j < ntree_limit; ++j) {
-          int tid = trees[j]->GetLeafIndex(feats, info.GetRoot(ridx));
-          preds[ridx * ntree_limit + j] = static_cast<bst_float>(tid);
-        }
-        feats.Drop(batch[i]);
-      }
-    }
-  }
-  // predict contributions
-  inline void PredContrib(DMatrix *p_fmat,
-                          std::vector<bst_float> *out_contribs,
-                          unsigned ntree_limit) {
-    const MetaInfo& info = p_fmat->info();
-    // number of valid trees
-    ntree_limit *= mparam.num_output_group;
-    if (ntree_limit == 0 || ntree_limit > trees.size()) {
-      ntree_limit = static_cast<unsigned>(trees.size());
-    }
-    const int ngroup = mparam.num_output_group;
-    size_t ncolumns = mparam.num_feature + 1;
-    // allocate space for (number of features + bias) times the number of rows
-    std::vector<bst_float>& contribs = *out_contribs;
-    contribs.resize(info.num_row * ncolumns * mparam.num_output_group);
-    // make sure contributions is zeroed, we could be reusing a previously allocated one
-    std::fill(contribs.begin(), contribs.end(), 0);
-    // initialize tree node mean values
-    #pragma omp parallel for schedule(static)
-    for (bst_omp_uint i = 0; i < ntree_limit; ++i) {
-      trees[i]->FillNodeMeanValues();
-    }
-    // start collecting the contributions
-    dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
-    const std::vector<bst_float>& base_margin = info.base_margin;
-    iter->BeforeFirst();
-    while (iter->Next()) {
-      const RowBatch& batch = iter->Value();
-      // parallel over local batch
-      const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
-      #pragma omp parallel for schedule(static)
-      for (bst_omp_uint i = 0; i < nsize; ++i) {
-        size_t row_idx = static_cast<size_t>(batch.base_rowid + i);
-        unsigned root_id = info.GetRoot(row_idx);
-        RegTree::FVec &feats = thread_temp[omp_get_thread_num()];
-        // loop over all classes
-        for (int gid = 0; gid < ngroup; ++gid) {
-          bst_float *p_contribs = &contribs[(row_idx * ngroup + gid) * ncolumns];
-          feats.Fill(batch[i]);
-          // calculate contributions
-          for (unsigned j = 0; j < ntree_limit; ++j) {
-            if (tree_info[j] != gid) {
-              continue;
-            }
-            trees[j]->CalculateContributions(feats, root_id, p_contribs);
-          }
-          feats.Drop(batch[i]);
-          // add base margin to BIAS
-          if (base_margin.size() != 0) {
-            p_contribs[ncolumns - 1] += base_margin[row_idx * ngroup + gid];
-          } else {
-            p_contribs[ncolumns - 1] += base_margin_;
-          }
-        }
-      }
-    }
-  }
-  // init thread buffers
-  inline void InitThreadTemp(int nthread) {
-    int prev_thread_temp_size = thread_temp.size();
-    if (prev_thread_temp_size < nthread) {
-      thread_temp.resize(nthread, RegTree::FVec());
-      for (int i = prev_thread_temp_size; i < nthread; ++i) {
-        thread_temp[i].Init(mparam.num_feature);
-      }
-    }
-  }
   // --- data structure ---
-  // base margin
-  bst_float base_margin_;
+  GBTreeModel model_;
   // training parameter
   GBTreeTrainParam tparam;
-  // model parameter
-  GBTreeModelParam mparam;
-  /*! \brief vector of trees stored in the model */
-  std::vector<std::unique_ptr<RegTree> > trees;
-  /*! \brief for the update process, a place to keep the initial trees */
-  std::vector<std::unique_ptr<RegTree> > trees_to_update;
-  /*! \brief some information indicator of the tree, reserved */
-  std::vector<int> tree_info;
   // ----training fields----
-  std::unordered_map<DMatrix*, CacheEntry> cache_;
   // configurations for tree
   std::vector<std::pair<std::string, std::string> > cfg;
-  // temporal storage for per thread
-  std::vector<RegTree::FVec> thread_temp;
   // the updaters that can be applied to each of tree
-  std::vector<std::unique_ptr<TreeUpdater> > updaters;
+  std::vector<std::unique_ptr<TreeUpdater>> updaters;
+
+  std::unique_ptr<Predictor> predictor;
 };
 
 // dart
@@ -659,15 +311,15 @@ class Dart : public GBTree {
 
   void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) override {
     GBTree::Configure(cfg);
-    if (trees.size() == 0) {
+    if (model_.trees.size() == 0) {
       dparam.InitAllowUnknown(cfg);
     }
   }
 
   void Load(dmlc::Stream* fi) override {
     GBTree::Load(fi);
-    weight_drop.resize(mparam.num_trees);
-    if (mparam.num_trees != 0) {
+    weight_drop.resize(model_.param.num_trees);
+    if (model_.param.num_trees != 0) {
       fi->Read(&weight_drop);
     }
   }
@@ -680,45 +332,140 @@ class Dart : public GBTree {
   }
 
   // predict the leaf scores with dropout if ntree_limit = 0
-  void Predict(DMatrix* p_fmat,
+  void PredictBatch(DMatrix* p_fmat,
                std::vector<bst_float>* out_preds,
                unsigned ntree_limit) override {
     DropTrees(ntree_limit);
     PredLoopInternal<Dart>(p_fmat, out_preds, 0, ntree_limit, true);
   }
 
-  void Predict(const SparseBatch::Inst& inst,
+  void PredictInstance(const SparseBatch::Inst& inst,
               std::vector<bst_float>* out_preds,
               unsigned ntree_limit,
               unsigned root_index) override {
     DropTrees(1);
     if (thread_temp.size() == 0) {
       thread_temp.resize(1, RegTree::FVec());
-      thread_temp[0].Init(mparam.num_feature);
+      thread_temp[0].Init(model_.param.num_feature);
     }
-    out_preds->resize(mparam.num_output_group);
-    ntree_limit *= mparam.num_output_group;
-    if (ntree_limit == 0 || ntree_limit > trees.size()) {
-      ntree_limit = static_cast<unsigned>(trees.size());
+    out_preds->resize(model_.param.num_output_group);
+    ntree_limit *= model_.param.num_output_group;
+    if (ntree_limit == 0 || ntree_limit > model_.trees.size()) {
+      ntree_limit = static_cast<unsigned>(model_.trees.size());
     }
     // loop over output groups
-    for (int gid = 0; gid < mparam.num_output_group; ++gid) {
+    for (int gid = 0; gid < model_.param.num_output_group; ++gid) {
       (*out_preds)[gid] =
           PredValue(inst, gid, root_index,
-                    &thread_temp[0], 0, ntree_limit) + base_margin_;
+                    &thread_temp[0], 0, ntree_limit) + model_.base_margin;
     }
   }
 
 protected:
   friend class GBTree;
+  // internal prediction loop
+  // add predictions to out_preds
+  template<typename Derived>
+  inline void PredLoopInternal(
+      DMatrix* p_fmat,
+      std::vector<bst_float>* out_preds,
+      unsigned tree_begin,
+      unsigned ntree_limit,
+      bool init_out_preds) {
+    int num_group = model_.param.num_output_group;
+    ntree_limit *= num_group;
+    if (ntree_limit == 0 || ntree_limit > model_.trees.size()) {
+      ntree_limit = static_cast<unsigned>(model_.trees.size());
+    }
+
+    if (init_out_preds) {
+      size_t n = num_group * p_fmat->info().num_row;
+      const std::vector<bst_float>& base_margin = p_fmat->info().base_margin;
+      out_preds->resize(n);
+      if (base_margin.size() != 0) {
+        CHECK_EQ(out_preds->size(), n);
+        std::copy(base_margin.begin(), base_margin.end(), out_preds->begin());
+      } else {
+        std::fill(out_preds->begin(), out_preds->end(), model_.base_margin);
+      }
+    }
+
+    if (num_group == 1) {
+      PredLoopSpecalize<Derived>(p_fmat, out_preds, 1,
+                                 tree_begin, ntree_limit);
+    } else {
+      PredLoopSpecalize<Derived>(p_fmat, out_preds, num_group,
+                                 tree_begin, ntree_limit);
+    }
+  }
+
+  template<typename Derived>
+  inline void PredLoopSpecalize(
+      DMatrix* p_fmat,
+      std::vector<bst_float>* out_preds,
+      int num_group,
+      unsigned tree_begin,
+      unsigned tree_end) {
+    const MetaInfo& info = p_fmat->info();
+    const int nthread = omp_get_max_threads();
+    CHECK_EQ(num_group, model_.param.num_output_group);
+    InitThreadTemp(nthread);
+    std::vector<bst_float>& preds = *out_preds;
+    CHECK_EQ(model_.param.size_leaf_vector, 0)
+        << "size_leaf_vector is enforced to 0 so far";
+    CHECK_EQ(preds.size(), p_fmat->info().num_row * num_group);
+    // start collecting the prediction
+    dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
+    Derived* self = static_cast<Derived*>(this);
+    iter->BeforeFirst();
+    while (iter->Next()) {
+      const RowBatch &batch = iter->Value();
+      // parallel over local batch
+      const int K = 8;
+      const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
+      const bst_omp_uint rest = nsize % K;
+      #pragma omp parallel for schedule(static)
+      for (bst_omp_uint i = 0; i < nsize - rest; i += K) {
+        const int tid = omp_get_thread_num();
+        RegTree::FVec& feats = thread_temp[tid];
+        int64_t ridx[K];
+        RowBatch::Inst inst[K];
+        for (int k = 0; k < K; ++k) {
+          ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
+        }
+        for (int k = 0; k < K; ++k) {
+          inst[k] = batch[i + k];
+        }
+        for (int k = 0; k < K; ++k) {
+          for (int gid = 0; gid < num_group; ++gid) {
+            const size_t offset = ridx[k] * num_group + gid;
+            preds[offset] +=
+                self->PredValue(inst[k], gid, info.GetRoot(ridx[k]),
+                                &feats, tree_begin, tree_end);
+          }
+        }
+      }
+      for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
+        RegTree::FVec& feats = thread_temp[0];
+        const int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
+        const RowBatch::Inst inst = batch[i];
+        for (int gid = 0; gid < num_group; ++gid) {
+          const size_t offset = ridx * num_group + gid;
+          preds[offset] +=
+              self->PredValue(inst, gid, info.GetRoot(ridx),
+                              &feats, tree_begin, tree_end);
+        }
+      }
+    }
+  }
   // commit new trees all at once
   void CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
                    int bst_group) override {
     for (size_t i = 0; i < new_trees.size(); ++i) {
-      trees.push_back(std::move(new_trees[i]));
-      tree_info.push_back(bst_group);
+      model_.trees.push_back(std::move(new_trees[i]));
+      model_.tree_info.push_back(bst_group);
     }
-    mparam.num_trees += static_cast<int>(new_trees.size());
+    model_.param.num_trees += static_cast<int>(new_trees.size());
     size_t num_drop = NormalizeTrees(new_trees.size());
     if (dparam.silent != 1) {
       LOG(INFO) << "drop " << num_drop << " trees, "
@@ -735,11 +482,11 @@ class Dart : public GBTree {
     bst_float psum = 0.0f;
     p_feats->Fill(inst);
     for (size_t i = tree_begin; i < tree_end; ++i) {
-      if (tree_info[i] == bst_group) {
+      if (model_.tree_info[i] == bst_group) {
         bool drop = (std::binary_search(idx_drop.begin(), idx_drop.end(), i));
         if (!drop) {
-          int tid = trees[i]->GetLeafIndex(*p_feats, root_index);
-          psum += weight_drop[i] * (*trees[i])[tid].leaf_value();
+          int tid = model_.trees[i]->GetLeafIndex(*p_feats, root_index);
+          psum += weight_drop[i] * (*model_.trees[i])[tid].leaf_value();
         }
       }
     }
@@ -825,6 +572,17 @@ class Dart : public GBTree {
     return num_drop;
   }
 
+  // init thread buffers
+  inline void InitThreadTemp(int nthread) {
+    int prev_thread_temp_size = thread_temp.size();
+    if (prev_thread_temp_size < nthread) {
+      thread_temp.resize(nthread, RegTree::FVec());
+      for (int i = prev_thread_temp_size; i < nthread; ++i) {
+        thread_temp[i].Init(model_.param.num_feature);
+      }
+    }
+  }
+
   // --- data structure ---
   // training parameter
   DartTrainParam dparam;
@@ -832,6 +590,8 @@ class Dart : public GBTree {
   std::vector<bst_float> weight_drop;
   // indexes of dropped trees
   std::vector<size_t> idx_drop;
+  // temporal storage for per thread
+  std::vector<RegTree::FVec> thread_temp;
 };
 
 // register the objective functions
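Note why PredLoopInternal and thread_temp move into Dart instead of the shared CPUPredictor: DART predictions depend on per-call dropout state (idx_drop, weight_drop), so its tree walk cannot be served by the stateless, cache-backed predictor. The accumulation rule in Dart::PredValue reduces to this (a sketch; leaf_value(i) is a hypothetical stand-in for the real leaf lookup):

    // Dropped trees contribute nothing; surviving trees are scaled by weight.
    // idx_drop must stay sorted for std::binary_search to be valid.
    bst_float psum = 0.0f;
    for (size_t i = tree_begin; i < tree_end; ++i) {
      if (!std::binary_search(idx_drop.begin(), idx_drop.end(), i)) {
        psum += weight_drop[i] * leaf_value(i);  // hypothetical helper
      }
    }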
diff --git a/src/gbm/gbtree_model.h b/src/gbm/gbtree_model.h
new file mode 100644
index 000000000000..aa201cce94d1
--- /dev/null
+++ b/src/gbm/gbtree_model.h
@@ -0,0 +1,140 @@
+/*!
+ * Copyright by Contributors 2017
+ */
+#pragma once
+#include <dmlc/io.h>
+#include <dmlc/parameter.h>
+#include <xgboost/tree_model.h>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace xgboost {
+namespace gbm {
+/*! \brief model parameters */
+struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
+  /*! \brief number of trees */
+  int num_trees;
+  /*! \brief number of roots */
+  int num_roots;
+  /*! \brief number of features to be used by trees */
+  int num_feature;
+  /*! \brief pad this space, for backward compatibility reason.*/
+  int pad_32bit;
+  /*! \brief deprecated padding space. */
+  int64_t num_pbuffer_deprecated;
+  /*!
+   * \brief how many output group a single instance can produce
+   *  this affects the behavior of number of output we have:
+   *    suppose we have n instance and k group, output will be k * n
+   */
+  int num_output_group;
+  /*! \brief size of leaf vector needed in tree */
+  int size_leaf_vector;
+  /*! \brief reserved parameters */
+  int reserved[32];
+  /*! \brief constructor */
+  GBTreeModelParam() {
+    std::memset(this, 0, sizeof(GBTreeModelParam));
+    static_assert(sizeof(GBTreeModelParam) == (4 + 2 + 2 + 32) * sizeof(int),
+                  "64/32 bit compatibility issue");
+  }
+  // declare parameters, only declare those that need to be set.
+  DMLC_DECLARE_PARAMETER(GBTreeModelParam) {
+    DMLC_DECLARE_FIELD(num_output_group)
+        .set_lower_bound(1)
+        .set_default(1)
+        .describe(
+            "Number of output groups to be predicted,"
+            " used for multi-class classification.");
+    DMLC_DECLARE_FIELD(num_roots).set_lower_bound(1).set_default(1).describe(
+        "Tree updater sequence.");
+    DMLC_DECLARE_FIELD(num_feature)
+        .set_lower_bound(0)
+        .describe("Number of features used for training and prediction.");
+    DMLC_DECLARE_FIELD(size_leaf_vector)
+        .set_lower_bound(0)
+        .set_default(0)
+        .describe("Reserved option for vector tree.");
+  }
+};
+
+struct GBTreeModel {
+  explicit GBTreeModel(bst_float base_margin) : base_margin(base_margin) {}
+  void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) {
+    // initialize model parameters if not yet been initialized.
+    if (trees.size() == 0) {
+      param.InitAllowUnknown(cfg);
+    }
+  }
+
+  void InitTreesToUpdate() {
+    if (trees_to_update.size() == 0u) {
+      for (size_t i = 0; i < trees.size(); ++i) {
+        trees_to_update.push_back(std::move(trees[i]));
+      }
+      trees.clear();
+      param.num_trees = 0;
+    }
+  }
+
+  void Load(dmlc::Stream* fi) {
+    CHECK_EQ(fi->Read(&param, sizeof(param)), sizeof(param))
+        << "GBTree: invalid model file";
+    trees.clear();
+    trees_to_update.clear();
+    for (int i = 0; i < param.num_trees; ++i) {
+      std::unique_ptr<RegTree> ptr(new RegTree());
+      ptr->Load(fi);
+      trees.push_back(std::move(ptr));
+    }
+    tree_info.resize(param.num_trees);
+    if (param.num_trees != 0) {
+      CHECK_EQ(
+          fi->Read(dmlc::BeginPtr(tree_info), sizeof(int) * param.num_trees),
+          sizeof(int) * param.num_trees);
+    }
+  }
+
+  void Save(dmlc::Stream* fo) const {
+    CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
+    fo->Write(&param, sizeof(param));
+    for (size_t i = 0; i < trees.size(); ++i) {
+      trees[i]->Save(fo);
+    }
+    if (tree_info.size() != 0) {
+      fo->Write(dmlc::BeginPtr(tree_info), sizeof(int) * tree_info.size());
+    }
+  }
+
+  std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
+                                     std::string format) const {
+    std::vector<std::string> dump;
+    for (size_t i = 0; i < trees.size(); i++) {
+      dump.push_back(trees[i]->DumpModel(fmap, with_stats, format));
+    }
+    return dump;
+  }
+  void CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
+                   int bst_group) {
+    for (size_t i = 0; i < new_trees.size(); ++i) {
+      trees.push_back(std::move(new_trees[i]));
+      tree_info.push_back(bst_group);
+    }
+    param.num_trees += static_cast<int>(new_trees.size());
+  }
+
+  // base margin
+  bst_float base_margin;
+  // model parameter
+  GBTreeModelParam param;
+  /*! \brief vector of trees stored in the model */
+  std::vector<std::unique_ptr<RegTree> > trees;
+  /*! \brief for the update process, a place to keep the initial trees */
+  std::vector<std::unique_ptr<RegTree> > trees_to_update;
+  /*! \brief some information indicator of the tree, reserved */
+  std::vector<int> tree_info;
+};
+}  // namespace gbm
+}  // namespace xgboost
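The static_assert in GBTreeModelParam pins the serialized header to a fixed byte count, since Load/Save stream the struct verbatim. Spelled out (assuming 4-byte int, which is what the assert enforces): 4 leading ints (num_trees, num_roots, num_feature, pad_32bit) + one int64_t counted as 2 ints + 2 ints (num_output_group, size_leaf_vector) + 32 reserved ints = 40 int-sized slots = 160 bytes. A standalone restatement of the same arithmetic:

    static_assert(4 * sizeof(int) + sizeof(int64_t) +
                      2 * sizeof(int) + 32 * sizeof(int) ==
                      (4 + 2 + 2 + 32) * sizeof(int),
                  "GBTreeModelParam layout: 40 int-sized slots (160 bytes)");

This is also why pad_32bit exists: it keeps the int64_t member at the same offset on both 32- and 64-bit builds.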
diff --git a/src/learner.cc b/src/learner.cc
index d26e0d68237d..9e225b03104a 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -517,7 +517,7 @@ class LearnerImpl : public Learner {
                      unsigned ntree_limit = 0) const {
     CHECK(gbm_.get() != nullptr)
         << "Predict must happen after Load or InitModel";
-    gbm_->Predict(data, out_preds, ntree_limit);
+    gbm_->PredictBatch(data, out_preds, ntree_limit);
   }
   // model parameter
   LearnerModelParam mparam;
diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc
new file mode 100644
index 000000000000..01ffe7bf5720
--- /dev/null
+++ b/src/predictor/cpu_predictor.cc
@@ -0,0 +1,314 @@
+/*!
+ * Copyright by Contributors 2017
+ */
+#include <xgboost/predictor.h>
+#include <xgboost/tree_model.h>
+#include <xgboost/tree_updater.h>
+#include "dmlc/logging.h"
+
+namespace xgboost {
+namespace predictor {
+
+class CPUPredictor : public Predictor {
+ protected:
+  static bst_float PredValue(const RowBatch::Inst& inst,
+                             const std::vector<std::unique_ptr<RegTree>>& trees,
+                             const std::vector<int>& tree_info, int bst_group,
+                             unsigned root_index, RegTree::FVec* p_feats,
+                             unsigned tree_begin, unsigned tree_end) {
+    bst_float psum = 0.0f;
+    p_feats->Fill(inst);
+    for (size_t i = tree_begin; i < tree_end; ++i) {
+      if (tree_info[i] == bst_group) {
+        int tid = trees[i]->GetLeafIndex(*p_feats, root_index);
+        psum += (*trees[i])[tid].leaf_value();
+      }
+    }
+    p_feats->Drop(inst);
+    return psum;
+  }
+
+  void InitOutPredictions(const MetaInfo& info,
+                          std::vector<bst_float>* out_preds,
+                          const gbm::GBTreeModel& model) const {
+    size_t n = model.param.num_output_group * info.num_row;
+    const std::vector<bst_float>& base_margin = info.base_margin;
+    out_preds->resize(n);
+    if (base_margin.size() != 0) {
+      CHECK_EQ(out_preds->size(), n);
+      std::copy(base_margin.begin(), base_margin.end(), out_preds->begin());
+    } else {
+      std::fill(out_preds->begin(), out_preds->end(), model.base_margin);
+    }
+  }
+  // init thread buffers
+  inline void InitThreadTemp(int nthread, int num_feature) {
+    int prev_thread_temp_size = thread_temp.size();
+    if (prev_thread_temp_size < nthread) {
+      thread_temp.resize(nthread, RegTree::FVec());
+      for (int i = prev_thread_temp_size; i < nthread; ++i) {
+        thread_temp[i].Init(num_feature);
+      }
+    }
+  }
+  inline void PredLoopSpecalize(DMatrix* p_fmat,
+                                std::vector<bst_float>* out_preds,
+                                const gbm::GBTreeModel& model, int num_group,
+                                unsigned tree_begin, unsigned tree_end) {
+    const MetaInfo& info = p_fmat->info();
+    const int nthread = omp_get_max_threads();
+    InitThreadTemp(nthread, model.param.num_feature);
+    std::vector<bst_float>& preds = *out_preds;
+    CHECK_EQ(model.param.size_leaf_vector, 0)
+        << "size_leaf_vector is enforced to 0 so far";
+    CHECK_EQ(preds.size(), p_fmat->info().num_row * num_group);
+    // start collecting the prediction
+    dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
+    iter->BeforeFirst();
+    while (iter->Next()) {
+      const RowBatch& batch = iter->Value();
+      // parallel over local batch
+      const int K = 8;
+      const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
+      const bst_omp_uint rest = nsize % K;
+#pragma omp parallel for schedule(static)
+      for (bst_omp_uint i = 0; i < nsize - rest; i += K) {
+        const int tid = omp_get_thread_num();
+        RegTree::FVec& feats = thread_temp[tid];
+        int64_t ridx[K];
+        RowBatch::Inst inst[K];
+        for (int k = 0; k < K; ++k) {
+          ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
+        }
+        for (int k = 0; k < K; ++k) {
+          inst[k] = batch[i + k];
+        }
+        for (int k = 0; k < K; ++k) {
+          for (int gid = 0; gid < num_group; ++gid) {
+            const size_t offset = ridx[k] * num_group + gid;
+            preds[offset] += this->PredValue(
+                inst[k], model.trees, model.tree_info, gid,
+                info.GetRoot(ridx[k]), &feats, tree_begin, tree_end);
+          }
+        }
+      }
+      for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
+        RegTree::FVec& feats = thread_temp[0];
+        const int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
+        const RowBatch::Inst inst = batch[i];
+        for (int gid = 0; gid < num_group; ++gid) {
+          const size_t offset = ridx * num_group + gid;
+          preds[offset] +=
+              this->PredValue(inst, model.trees, model.tree_info, gid,
+                              info.GetRoot(ridx), &feats, tree_begin, tree_end);
+        }
+      }
+    }
+  }
+
+  /**
+   * \fn bool PredictFromCache(DMatrix* dmat, std::vector<bst_float>* out_preds, const gbm::GBTreeModel& model, unsigned ntree_limit = 0)
+   *
+   * \brief Attempt to predict from cache.
+   *
+   * \return True if it succeeds, false if it fails.
+   */
+  bool PredictFromCache(DMatrix* dmat, std::vector<bst_float>* out_preds,
+                        const gbm::GBTreeModel& model,
+                        unsigned ntree_limit = 0) {
+    if (ntree_limit == 0 ||
+        ntree_limit * model.param.num_output_group >= model.trees.size()) {
+      auto it = cache_.find(dmat);
+      if (it != cache_.end()) {
+        std::vector<bst_float>& y = it->second.predictions;
+        if (y.size() != 0) {
+          out_preds->resize(y.size());
+          std::copy(y.begin(), y.end(), out_preds->begin());
+          return true;
+        }
+      }
+    }
+
+    return false;
+  }
+
+  void PredLoopInternal(DMatrix* dmat, std::vector<bst_float>* out_preds,
+                        const gbm::GBTreeModel& model, int tree_begin,
+                        unsigned ntree_limit) {
+    // TODO(Rory): Check if this specialisation actually improves performance
+    if (model.param.num_output_group == 1) {
+      PredLoopSpecalize(dmat, out_preds, model, 1, tree_begin, ntree_limit);
+    } else {
+      PredLoopSpecalize(dmat, out_preds, model, model.param.num_output_group,
+                        tree_begin, ntree_limit);
+    }
+  }
+
+ public:
+  void PredictBatch(DMatrix* dmat, std::vector<bst_float>* out_preds,
+                    const gbm::GBTreeModel& model, int tree_begin,
+                    unsigned ntree_limit = 0) override {
+    if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) {
+      return;
+    }
+
+    this->InitOutPredictions(dmat->info(), out_preds, model);
+
+    ntree_limit *= model.param.num_output_group;
+    if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
+      ntree_limit = static_cast<unsigned>(model.trees.size());
+    }
+
+    this->PredLoopInternal(dmat, out_preds, model, tree_begin, ntree_limit);
+  }
+
+  void UpdatePredictionCache(
+      const gbm::GBTreeModel& model,
+      std::vector<std::unique_ptr<TreeUpdater>>* updaters,
+      int num_new_trees) override {
+    int old_ntree = model.trees.size() - num_new_trees;
+    // update cache entry
+    for (auto& kv : cache_) {
+      PredictionCacheEntry& e = kv.second;
+
+      if (e.predictions.size() == 0) {
+        InitOutPredictions(e.data->info(), &(e.predictions), model);
+        PredLoopInternal(e.data.get(), &(e.predictions), model, 0,
+                         model.trees.size());
+      } else if (model.param.num_output_group == 1 && updaters->size() > 0 &&
+                 num_new_trees == 1 &&
+                 updaters->back()->UpdatePredictionCache(e.data.get(),
+                                                         &(e.predictions))) {
+        {}  // do nothing
+      } else {
+        PredLoopInternal(e.data.get(), &(e.predictions), model, old_ntree,
+                         model.trees.size());
+      }
+    }
+  }
+
+  void PredictInstance(const SparseBatch::Inst& inst,
+                       std::vector<bst_float>* out_preds,
+                       const gbm::GBTreeModel& model, unsigned ntree_limit,
+                       unsigned root_index) override {
+    if (thread_temp.size() == 0) {
+      thread_temp.resize(1, RegTree::FVec());
+      thread_temp[0].Init(model.param.num_feature);
+    }
+    ntree_limit *= model.param.num_output_group;
+    if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
+      ntree_limit = static_cast<unsigned>(model.trees.size());
+    }
+    out_preds->resize(model.param.num_output_group *
+                      (model.param.size_leaf_vector + 1));
+    // loop over output groups
+    for (int gid = 0; gid < model.param.num_output_group; ++gid) {
+      (*out_preds)[gid] =
+          PredValue(inst, model.trees, model.tree_info, gid, root_index,
+                    &thread_temp[0], 0, ntree_limit) +
+          model.base_margin;
+    }
+  }
+  void PredictLeaf(DMatrix* p_fmat, std::vector<bst_float>* out_preds,
+                   const gbm::GBTreeModel& model, unsigned ntree_limit) override {
+    const int nthread = omp_get_max_threads();
+    InitThreadTemp(nthread, model.param.num_feature);
+    const MetaInfo& info = p_fmat->info();
+    // number of valid trees
+    ntree_limit *= model.param.num_output_group;
+    if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
+      ntree_limit = static_cast<unsigned>(model.trees.size());
+    }
+    std::vector<bst_float>& preds = *out_preds;
+    preds.resize(info.num_row * ntree_limit);
+    // start collecting the prediction
+    dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
+    iter->BeforeFirst();
+    while (iter->Next()) {
+      const RowBatch& batch = iter->Value();
+      // parallel over local batch
+      const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
+#pragma omp parallel for schedule(static)
+      for (bst_omp_uint i = 0; i < nsize; ++i) {
+        const int tid = omp_get_thread_num();
+        size_t ridx = static_cast<size_t>(batch.base_rowid + i);
+        RegTree::FVec& feats = thread_temp[tid];
+        feats.Fill(batch[i]);
+        for (unsigned j = 0; j < ntree_limit; ++j) {
+          int tid = model.trees[j]->GetLeafIndex(feats, info.GetRoot(ridx));
+          preds[ridx * ntree_limit + j] = static_cast<bst_float>(tid);
+        }
+        feats.Drop(batch[i]);
+      }
+    }
+  }
+
+  void PredictContribution(DMatrix* p_fmat,
+                           std::vector<bst_float>* out_contribs,
+                           const gbm::GBTreeModel& model,
+                           unsigned ntree_limit) override {
+    const int nthread = omp_get_max_threads();
+    InitThreadTemp(nthread, model.param.num_feature);
+    const MetaInfo& info = p_fmat->info();
+    // number of valid trees
+    ntree_limit *= model.param.num_output_group;
+    if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
+      ntree_limit = static_cast<unsigned>(model.trees.size());
+    }
+    const int ngroup = model.param.num_output_group;
+    size_t ncolumns = model.param.num_feature + 1;
+    // allocate space for (number of features + bias) times the number of rows
+    std::vector<bst_float>& contribs = *out_contribs;
+    contribs.resize(info.num_row * ncolumns * model.param.num_output_group);
+    // make sure contributions is zeroed, we could be reusing a previously
+    // allocated one
+    std::fill(contribs.begin(), contribs.end(), 0);
+// initialize tree node mean values
+#pragma omp parallel for schedule(static)
+    for (bst_omp_uint i = 0; i < ntree_limit; ++i) {
+      model.trees[i]->FillNodeMeanValues();
+    }
+    // start collecting the contributions
+    dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
+    const std::vector<bst_float>& base_margin = info.base_margin;
+    iter->BeforeFirst();
+    while (iter->Next()) {
+      const RowBatch& batch = iter->Value();
+      // parallel over local batch
+      const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
+#pragma omp parallel for schedule(static)
+      for (bst_omp_uint i = 0; i < nsize; ++i) {
+        size_t row_idx = static_cast<size_t>(batch.base_rowid + i);
+        unsigned root_id = info.GetRoot(row_idx);
+        RegTree::FVec& feats = thread_temp[omp_get_thread_num()];
+        // loop over all classes
+        for (int gid = 0; gid < ngroup; ++gid) {
+          bst_float* p_contribs =
+              &contribs[(row_idx * ngroup + gid) * ncolumns];
+          feats.Fill(batch[i]);
+          // calculate contributions
+          for (unsigned j = 0; j < ntree_limit; ++j) {
+            if (model.tree_info[j] != gid) {
+              continue;
+            }
+            model.trees[j]->CalculateContributions(feats, root_id, p_contribs);
+          }
+          feats.Drop(batch[i]);
+          // add base margin to BIAS
+          if (base_margin.size() != 0) {
+            p_contribs[ncolumns - 1] += base_margin[row_idx * ngroup + gid];
+          } else {
+            p_contribs[ncolumns - 1] += model.base_margin;
+          }
+        }
+      }
+    }
+  }
+  std::vector<RegTree::FVec> thread_temp;
+};
+
+XGBOOST_REGISTER_PREDICTOR(CPUPredictor, "cpu_predictor")
+    .describe("Make predictions using CPU.")
+    .set_body([]() { return new CPUPredictor(); });
+}  // namespace predictor
+}  // namespace xgboost
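Two details of this file are worth calling out. First, PredictFromCache only answers when the request covers the full model (ntree_limit == 0, or the group-scaled limit reaches trees.size()); partial-model requests always recompute. Second, PredLoopSpecalize tiles the row loop by K = 8: the parallel loop consumes whole blocks of eight rows, and a scalar tail handles the nsize % K leftovers. The skeleton of that pattern, stripped of the tree-walk details (process_row is a hypothetical stand-in for the per-row work):

    // Skeleton of the K = 8 blocking used in PredLoopSpecalize above.
    const int K = 8;
    const int rest = nsize % K;
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < nsize - rest; i += K) {
      for (int k = 0; k < K; ++k) {
        process_row(i + k);  // eight rows per iteration amortize loop overhead
      }
    }
    for (int i = nsize - rest; i < nsize; ++i) {
      process_row(i);  // scalar tail for the final nsize % K rows
    }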
diff --git a/src/predictor/predictor.cc b/src/predictor/predictor.cc
new file mode 100644
index 000000000000..82200771a760
--- /dev/null
+++ b/src/predictor/predictor.cc
@@ -0,0 +1,25 @@
+/*!
+ * Copyright by Contributors 2017
+ */
+#include <xgboost/predictor.h>
+#include <dmlc/registry.h>
+
+namespace dmlc {
+DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg);
+}  // namespace dmlc
+namespace xgboost {
+void Predictor::InitCache(const std::vector<std::shared_ptr<DMatrix> >& cache) {
+  for (const std::shared_ptr<DMatrix>& d : cache) {
+    PredictionCacheEntry e;
+    e.data = d;
+    cache_[d.get()] = std::move(e);
+  }
+}
+Predictor* Predictor::Create(std::string name) {
+  auto* e = ::dmlc::Registry<PredictorReg>::Get()->Find(name);
+  if (e == nullptr) {
+    LOG(FATAL) << "Unknown predictor type " << name;
+  }
+  return (e->body)();
+}
+}  // namespace xgboost
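With the registry enabled here, adding another backend is a matter of subclassing Predictor and registering the factory; a hypothetical sketch (no gpu_predictor ships in this patch, and the overrides are elided for brevity):

    // Hypothetical second backend plugging into the registry above.
    class GPUPredictor : public xgboost::Predictor {
      // ... override PredictBatch, PredictInstance, PredictLeaf,
      // PredictContribution and UpdatePredictionCache ...
    };

    XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
        .describe("Make predictions using a GPU.")
        .set_body([]() { return new GPUPredictor(); });

GBTree would then pick it up with Predictor::Create("gpu_predictor"), with no other changes to the boosting code.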