diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h
index 19a484109988..40136c5c7991 100644
--- a/src/common/column_matrix.h
+++ b/src/common/column_matrix.h
@@ -8,11 +8,11 @@
 #ifndef XGBOOST_COMMON_COLUMN_MATRIX_H_
 #define XGBOOST_COMMON_COLUMN_MATRIX_H_
 
-#include
 #include
 #include
 #include "hist_util.h"
+
 namespace xgboost {
 namespace common {
@@ -51,10 +51,6 @@ class Column {
   }
   const size_t* GetRowData() const { return row_ind_; }
 
-  const uint32_t* GetIndex() const {
-    return index_;
-  }
-
  private:
   ColumnType type_;
   const uint32_t* index_;
@@ -117,6 +113,7 @@ class ColumnMatrix {
       boundary_[fid].index_end = accum_index_;
       boundary_[fid].row_ind_end = accum_row_ind_;
     }
+
     index_.resize(boundary_[nfeature - 1].index_end);
     row_ind_.resize(boundary_[nfeature - 1].row_ind_end);
diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc
index 83e8c117ba54..e9390f2172e3 100644
--- a/src/common/hist_util.cc
+++ b/src/common/hist_util.cc
@@ -1,6 +1,6 @@
 /*!
  * Copyright 2017-2019 by Contributors
- * \file hist_util.cc
+ * \file hist_util.cc
 */
 #include
 #include
@@ -418,8 +418,8 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
   }
 
 #pragma omp parallel for num_threads(nthread) schedule(static)
-  for (int32_t idx = 0; idx < int32_t(nbins); ++idx) {
-    for (int32_t tid = 0; tid < nthread; ++tid) {
+  for (bst_omp_uint idx = 0; idx < bst_omp_uint(nbins); ++idx) {
+    for (size_t tid = 0; tid < nthread; ++tid) {
       hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
     }
   }
@@ -569,7 +569,7 @@ FastFeatureGrouping(const GHistIndexMatrix& gmat,
     for (auto fid : group) {
       nnz += feature_nnz[fid];
    }
-    float nnz_rate = static_cast<float>(nnz) / nrow;
+    double nnz_rate = static_cast<double>(nnz) / nrow;
     // take apart small sparse group, due it will not gain on speed
     if (nnz_rate <= param.sparse_threshold) {
       for (auto fid : group) {
@@ -654,129 +654,151 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
   }
 }
 
-// used when data layout is kDenseDataZeroBased or kDenseDataOneBased
-// it means that "row_ptr" is not needed for hist computations
-void BuildHistLocalDense(size_t istart, size_t iend, size_t nrows, const size_t* rid,
-    const uint32_t* index, const GradientPair::ValueT* pgh, const size_t* row_ptr,
-    GradStatHist::GradType* data_local_hist, GradStatHist* grad_stat_global) {
-  GradStatHist grad_stat;  // make local var to prevent false sharing
+void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
+                             const RowSetCollection::Elem row_indices,
+                             const GHistIndexMatrix& gmat,
+                             GHistRow hist) {
+  const size_t nthread = static_cast<size_t>(this->nthread_);
+  data_.resize(nbins_ * nthread_);
+
+  const size_t* rid = row_indices.begin;
+  const size_t nrows = row_indices.Size();
+  const uint32_t* index = gmat.index.data();
+  const size_t* row_ptr = gmat.row_ptr.data();
+  const float* pgh = reinterpret_cast<const float*>(gpair.data());
+
+  double* hist_data = reinterpret_cast<double*>(hist.data());
+  double* data = reinterpret_cast<double*>(data_.data());
+
+  const size_t block_size = 512;
+  size_t n_blocks = nrows/block_size;
+  n_blocks += !!(nrows - n_blocks*block_size);
+
+  const size_t nthread_to_process = std::min(nthread, n_blocks);
+  memset(thread_init_.data(), '\0', nthread_to_process*sizeof(size_t));
 
-  const size_t n_features = row_ptr[rid[istart]+1] - row_ptr[rid[istart]];
   const size_t cache_line_size = 64;
-  const size_t prefetch_step = cache_line_size / sizeof(*index);
   const size_t prefetch_offset = 10;
-
   size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid);
   no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size;
 
-  // if read each row in some block of bin-matrix - it's dense block
-  // and we dont need SW prefetch in this case
-  const bool denseBlock = (rid[iend-1] - rid[istart]) == (iend - istart - 1);
+#pragma omp parallel for num_threads(nthread_to_process) schedule(guided)
+  for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
+    dmlc::omp_uint tid = omp_get_thread_num();
+    double* data_local_hist = ((nthread_to_process == 1) ? hist_data :
+                               reinterpret_cast<double*>(data_.data() + tid * nbins_));
 
-  if (iend < nrows - no_prefetch_size && !denseBlock) {
-    for (size_t i = istart; i < iend; ++i) {
-      const size_t icol_start = rid[i] * n_features;
-      const size_t icol_start_prefetch = rid[i+prefetch_offset] * n_features;
-      const size_t idx_gh = 2*rid[i];
+    if (!thread_init_[tid]) {
+      memset(data_local_hist, '\0', 2*nbins_*sizeof(double));
+      thread_init_[tid] = true;
+    }
 
-      PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
+    const size_t istart = iblock*block_size;
+    const size_t iend = (((iblock+1)*block_size > nrows) ? nrows : istart + block_size);
+    for (size_t i = istart; i < iend; ++i) {
+      const size_t icol_start = row_ptr[rid[i]];
+      const size_t icol_end = row_ptr[rid[i]+1];
 
-      for (size_t j = icol_start_prefetch; j < icol_start_prefetch + n_features;
-          j += prefetch_step) {
-        PREFETCH_READ_T0(index + j);
+      if (i < nrows - no_prefetch_size) {
+        PREFETCH_READ_T0(row_ptr + rid[i + prefetch_offset]);
+        PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
       }
 
-      grad_stat.sum_grad += pgh[idx_gh];
-      grad_stat.sum_hess += pgh[idx_gh+1];
-
-      for (size_t j = icol_start; j < icol_start + n_features; ++j) {
+      for (size_t j = icol_start; j < icol_end; ++j) {
         const uint32_t idx_bin = 2*index[j];
+        const size_t idx_gh = 2*rid[i];
+
         data_local_hist[idx_bin] += pgh[idx_gh];
         data_local_hist[idx_bin+1] += pgh[idx_gh+1];
       }
     }
-  } else {
-    for (size_t i = istart; i < iend; ++i) {
-      const size_t icol_start = rid[i] * n_features;
-      const size_t idx_gh = 2*rid[i];
-      grad_stat.sum_grad += pgh[idx_gh];
-      grad_stat.sum_hess += pgh[idx_gh+1];
-
-      for (size_t j = icol_start; j < icol_start + n_features; ++j) {
-        const uint32_t idx_bin = 2*index[j];
-        data_local_hist[idx_bin] += pgh[idx_gh];
-        data_local_hist[idx_bin+1] += pgh[idx_gh+1];
-      }
-    }
   }
-  grad_stat_global->Add(grad_stat);
-}
-
-// used when data layout is kSparseData
-// it means that "row_ptr" is needed for hist computations
-void BuildHistLocalSparse(size_t istart, size_t iend, size_t nrows, const size_t* rid,
-    const uint32_t* index, const GradientPair::ValueT* pgh, const size_t* row_ptr,
-    GradStatHist::GradType* data_local_hist, GradStatHist* grad_stat_global) {
-  GradStatHist grad_stat;  // make local var to prevent false sharing
-
-  const size_t cache_line_size = 64;
-  const size_t prefetch_step = cache_line_size / sizeof(index[0]);
-  const size_t prefetch_offset = 10;
-
-  size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid);
-  no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size;
-
-  // if read each row in some block of bin-matrix - it's dense block
-  // and we dont need SW prefetch in this case
-  const bool denseBlock = (rid[iend-1] - rid[istart]) == (iend - istart);
+  if (nthread_to_process > 1) {
+    const size_t size = (2*nbins_);
+    const size_t block_size = 1024;
+    size_t n_blocks = size/block_size;
+    n_blocks += !!(size - n_blocks*block_size);
 
-  if (iend < nrows - no_prefetch_size && !denseBlock) {
-    for (size_t i = istart; i < iend; ++i) {
-      const size_t icol_start = row_ptr[rid[i]];
-      const size_t icol_end = row_ptr[rid[i]+1];
-      const size_t idx_gh = 2*rid[i];
+    size_t n_worked_bins = 0;
+    for (size_t i = 0; i < nthread_to_process; ++i) {
+      if (thread_init_[i]) {
+        thread_init_[n_worked_bins++] = i;
+      }
+    }
 
-      const size_t icol_start10 = row_ptr[rid[i+prefetch_offset]];
-      const size_t icol_end10 = row_ptr[rid[i+prefetch_offset]+1];
+#pragma omp parallel for num_threads(std::min(nthread, n_blocks)) schedule(guided)
+    for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
+      const size_t istart = iblock * block_size;
+      const size_t iend = (((iblock + 1) * block_size > size) ? size : istart + block_size);
 
-      PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
+      const size_t bin = 2 * thread_init_[0] * nbins_;
+      memcpy(hist_data + istart, (data + bin + istart), sizeof(double) * (iend - istart));
 
-      for (size_t j = icol_start10; j < icol_end10; j+=prefetch_step) {
-        PREFETCH_READ_T0(index + j);
+      for (size_t i_bin_part = 1; i_bin_part < n_worked_bins; ++i_bin_part) {
+        const size_t bin = 2 * thread_init_[i_bin_part] * nbins_;
+        for (size_t i = istart; i < iend; i++) {
+          hist_data[i] += data[bin + i];
+        }
       }
+    }
+  }
+}
 
-      grad_stat.sum_grad += pgh[idx_gh];
-      grad_stat.sum_hess += pgh[idx_gh+1];
-
-      for (size_t j = icol_start; j < icol_end; ++j) {
-        const uint32_t idx_bin = 2*index[j];
-        data_local_hist[idx_bin] += pgh[idx_gh];
-        data_local_hist[idx_bin+1] += pgh[idx_gh+1];
+void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
+                                  const RowSetCollection::Elem row_indices,
+                                  const GHistIndexBlockMatrix& gmatb,
+                                  GHistRow hist) {
+  constexpr int kUnroll = 8;  // loop unrolling factor
+  const size_t nblock = gmatb.GetNumBlock();
+  const size_t nrows = row_indices.end - row_indices.begin;
+  const size_t rest = nrows % kUnroll;
+
+#if defined(_OPENMP)
+  const auto nthread = static_cast<dmlc::omp_uint>(this->nthread_);  // NOLINT
+#endif  // defined(_OPENMP)
+  tree::GradStats* p_hist = hist.data();
+
+#pragma omp parallel for num_threads(nthread) schedule(guided)
+  for (bst_omp_uint bid = 0; bid < nblock; ++bid) {
+    auto gmat = gmatb[bid];
+
+    for (size_t i = 0; i < nrows - rest; i += kUnroll) {
+      size_t rid[kUnroll];
+      size_t ibegin[kUnroll];
+      size_t iend[kUnroll];
+      GradientPair stat[kUnroll];
+
+      for (int k = 0; k < kUnroll; ++k) {
+        rid[k] = row_indices.begin[i + k];
+        ibegin[k] = gmat.row_ptr[rid[k]];
+        iend[k] = gmat.row_ptr[rid[k] + 1];
+        stat[k] = gpair[rid[k]];
+      }
+      for (int k = 0; k < kUnroll; ++k) {
+        for (size_t j = ibegin[k]; j < iend[k]; ++j) {
+          const uint32_t bin = gmat.index[j];
+          p_hist[bin].Add(stat[k]);
+        }
       }
     }
-  } else {
-    for (size_t i = istart; i < iend; ++i) {
-      const size_t icol_start = row_ptr[rid[i]];
-      const size_t icol_end = row_ptr[rid[i]+1];
-      const size_t idx_gh = 2*rid[i];
-
-      grad_stat.sum_grad += pgh[idx_gh];
-      grad_stat.sum_hess += pgh[idx_gh+1];
-
-      for (size_t j = icol_start; j < icol_end; ++j) {
-        const uint32_t idx_bin = 2*index[j];
-        data_local_hist[idx_bin] += pgh[idx_gh];
-        data_local_hist[idx_bin+1] += pgh[idx_gh+1];
+    for (size_t i = nrows - rest; i < nrows; ++i) {
+      const size_t rid = row_indices.begin[i];
+      const size_t ibegin = gmat.row_ptr[rid];
+      const size_t iend = gmat.row_ptr[rid + 1];
+      const GradientPair stat = gpair[rid];
+      for (size_t j = ibegin; j < iend; ++j) {
+        const uint32_t bin = gmat.index[j];
+        p_hist[bin].Add(stat);
       }
     }
   }
-  grad_stat_global->Add(grad_stat);
 }
 
-void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
-  GradStatHist* p_self = self.data();
-  GradStatHist* p_sibling = sibling.data();
-  GradStatHist* p_parent = parent.data();
+void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
+  tree::GradStats* p_self = self.data();
+  tree::GradStats* p_sibling = sibling.data();
+  tree::GradStats* p_parent = parent.data();
 
   const size_t size = self.size();
   CHECK_EQ(sibling.size(), size);
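
The rewritten GHistBuilder::BuildHist above accumulates gradient pairs into one private histogram per OpenMP thread and then reduces the per-thread copies into the output row, which keeps atomics out of the hot inner loop. Below is a standalone sketch of that pattern under simplified assumptions (a CSR-style bin matrix; GradHess and every other name here are illustrative stand-ins, not the XGBoost API):

    // Sketch: per-thread histogram accumulation followed by a reduction.
    #include <omp.h>
    #include <cstddef>
    #include <vector>

    struct GradHess { double grad = 0.0, hess = 0.0; };

    // rows:  CSR-style row pointers into `bins` (nrows + 1 entries)
    // bins:  global bin id of every non-zero entry
    // gpair: one (grad, hess) pair per row
    std::vector<GradHess> BuildHistogram(const std::vector<std::size_t>& rows,
                                         const std::vector<unsigned>& bins,
                                         const std::vector<GradHess>& gpair,
                                         std::size_t nbins) {
      const int nthread = omp_get_max_threads();
      // one private histogram per thread to avoid atomics / false sharing
      std::vector<GradHess> tloc(static_cast<std::size_t>(nthread) * nbins);
      const std::ptrdiff_t nrows = static_cast<std::ptrdiff_t>(gpair.size());

    #pragma omp parallel for schedule(static)
      for (std::ptrdiff_t i = 0; i < nrows; ++i) {
        GradHess* hist = tloc.data() + omp_get_thread_num() * nbins;
        for (std::size_t j = rows[i]; j < rows[i + 1]; ++j) {
          hist[bins[j]].grad += gpair[i].grad;
          hist[bins[j]].hess += gpair[i].hess;
        }
      }

      // reduce the per-thread copies into the result
      std::vector<GradHess> out(nbins);
      for (int t = 0; t < nthread; ++t) {
        for (std::size_t b = 0; b < nbins; ++b) {
          out[b].grad += tloc[t * nbins + b].grad;
          out[b].hess += tloc[t * nbins + b].hess;
        }
      }
      return out;
    }

The memory cost is one histogram per thread, the same trade-off the patch makes by sizing data_ to nbins_ * nthread_.
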
diff --git a/src/common/hist_util.h b/src/common/hist_util.h
index 176f1b495901..b7ccd54adf8a 100644
--- a/src/common/hist_util.h
+++ b/src/common/hist_util.h
@@ -19,45 +19,9 @@
 #include "../tree/param.h"
 #include "./quantile.h"
 #include "./timer.h"
-#include "random.h"
+#include "../include/rabit/rabit.h"
 
 namespace xgboost {
-
-/*!
- * \brief A C-style array with in-stack allocation. As long as the array is smaller than
- * MaxStackSize, it will be allocated inside the stack. Otherwise, it will be
- * heap-allocated.
- */
-template<typename T, size_t MaxStackSize>
-class MemStackAllocator {
- public:
-  explicit MemStackAllocator(size_t required_size): required_size_(required_size) {
-  }
-
-  T* Get() {
-    if (!ptr_) {
-      if (MaxStackSize >= required_size_) {
-        ptr_ = stack_mem_;
-      } else {
-        ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
-        do_free_ = true;
-      }
-    }
-
-    return ptr_;
-  }
-
-  ~MemStackAllocator() {
-    if (do_free_) free(ptr_);
-  }
-
- private:
-  T* ptr_ = nullptr;
-  bool do_free_ = false;
-  size_t required_size_;
-  T stack_mem_[MaxStackSize];
-};
-
 namespace common {
 
 /*
@@ -287,7 +251,7 @@ class DenseCuts : public CutsBuilder {
 
 // FIXME(trivialfis): Merge this into generic cut builder.
 /*! \brief Builds the cut matrix on the GPU.
- *
+ *
  * \return The row stride across the entire dataset.
 */
 size_t DeviceSketch(int device,
@@ -303,10 +267,9 @@ size_t DeviceSketch(int device,
 */
 struct GHistIndexMatrix {
   /*! \brief row pointer to rows by element position */
-  // std::vector<size_t> row_ptr;
-  SimpleArray<size_t> row_ptr;
+  std::vector<size_t> row_ptr;
   /*! \brief The index data */
-  SimpleArray<uint32_t> index;
+  std::vector<uint32_t> index;
   /*! \brief hit count of each index */
   std::vector<size_t> hit_count;
   /*! \brief The corresponding cuts */
@@ -377,63 +340,12 @@ class GHistIndexBlockMatrix {
 };
 
 /*!
- * \brief used instead of GradStats to have float instead of double to reduce histograms
- * this improves performance by 10-30% and memory consumption for histograms by 2x
- * accuracy in both cases is the same
+ * \brief histogram of gradient statistics for a single node.
+ *  Consists of multiple GradStats, each entry showing total gradient statistics
+ *  for that particular bin
+ *  Uses global bin id so as to represent all features simultaneously
 */
-struct GradStatHist {
-  typedef float GradType;
-  /*! \brief sum gradient statistics */
-  GradType sum_grad;
-  /*! \brief sum hessian statistics */
-  GradType sum_hess;
-
-  GradStatHist() : sum_grad{0}, sum_hess{0} {
-    static_assert(sizeof(GradStatHist) == 8,
-                  "Size of GradStatHist is not 8 bytes.");
-  }
-
-  inline void Add(const GradStatHist& b) {
-    sum_grad += b.sum_grad;
-    sum_hess += b.sum_hess;
-  }
-
-  inline void Add(const tree::GradStats& b) {
-    sum_grad += b.sum_grad;
-    sum_hess += b.sum_hess;
-  }
-
-  inline void Add(const GradientPair& p) {
-    this->Add(p.GetGrad(), p.GetHess());
-  }
-
-  inline void Add(const GradType& grad, const GradType& hess) {
-    sum_grad += grad;
-    sum_hess += hess;
-  }
-
-  inline tree::GradStats ToGradStat() const {
-    return tree::GradStats(sum_grad, sum_hess);
-  }
-
-  inline void SetSubstract(const GradStatHist& a, const GradStatHist& b) {
-    sum_grad = a.sum_grad - b.sum_grad;
-    sum_hess = a.sum_hess - b.sum_hess;
-  }
-
-  inline void SetSubstract(const tree::GradStats& a, const GradStatHist& b) {
-    sum_grad = a.sum_grad - b.sum_grad;
-    sum_hess = a.sum_hess - b.sum_hess;
-  }
-
-  inline GradType GetGrad() const { return sum_grad; }
-  inline GradType GetHess() const { return sum_hess; }
-  inline static void Reduce(GradStatHist& a, const GradStatHist& b) { // NOLINT(*)
-    a.Add(b);
-  }
-};
-
-using GHistRow = Span<GradStatHist>;
+using GHistRow = Span<tree::GradStats>;
 
 /*!
  * \brief histogram of gradient statistics for multiple nodes
@@ -441,42 +353,48 @@ using GHistRow = Span<tree::GradStats>;
 class HistCollection {
  public:
   // access histogram for i-th node
-  inline GHistRow operator[](bst_uint nid) {
-    AddHistRow(nid);
-    return { const_cast<GradStatHist*>(dmlc::BeginPtr(data_arr_[nid])), nbins_};
+  GHistRow operator[](bst_uint nid) const {
+    constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
+    CHECK_NE(row_ptr_[nid], kMax);
+    tree::GradStats* ptr =
+        const_cast<tree::GradStats*>(dmlc::BeginPtr(data_) + row_ptr_[nid]);
+    return {ptr, nbins_};
   }
 
   // have we computed a histogram for i-th node?
-  inline bool RowExists(bst_uint nid) const {
-    return nid < data_arr_.size();
+  bool RowExists(bst_uint nid) const {
+    const uint32_t k_max = std::numeric_limits<uint32_t>::max();
+    return (nid < row_ptr_.size() && row_ptr_[nid] != k_max);
   }
 
   // initialize histogram collection
-  inline void Init(uint32_t nbins) {
-    if (nbins_ != nbins) {
-      data_arr_.clear();
-      nbins_ = nbins;
-    }
+  void Init(uint32_t nbins) {
+    nbins_ = nbins;
+    row_ptr_.clear();
+    data_.clear();
   }
 
   // create an empty histogram for i-th node
-  inline void AddHistRow(bst_uint nid) {
-    if (data_arr_.size() <= nid) {
-      size_t prev = data_arr_.size();
-      data_arr_.resize(nid + 1);
-
-      for (size_t i = prev; i < data_arr_.size(); ++i) {
-        data_arr_[i].resize(nbins_);
-      }
+  void AddHistRow(bst_uint nid) {
+    constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
+    if (nid >= row_ptr_.size()) {
+      row_ptr_.resize(nid + 1, kMax);
     }
+    CHECK_EQ(row_ptr_[nid], kMax);
+
+    row_ptr_[nid] = data_.size();
+    data_.resize(data_.size() + nbins_);
   }
 
 private:
   /*! \brief number of all bins over all features */
-  uint32_t nbins_ = 0;
-  std::vector<std::vector<GradStatHist>> data_arr_;
-};
+  uint32_t nbins_;
+  std::vector<tree::GradStats> data_;
+
+  /*! \brief row_ptr_[nid] locates bin for histogram of node nid */
+  std::vector<uint32_t> row_ptr_;
+};
 
 /*!
 * \brief builder for histograms of gradient statistics
@@ -487,55 +405,21 @@ class GHistBuilder {
  inline void Init(size_t nthread, uint32_t nbins) {
    nthread_ = nthread;
    nbins_ = nbins;
+    thread_init_.resize(nthread_);
  }
 
+  // construct a histogram via histogram aggregation
+  void BuildHist(const std::vector<GradientPair>& gpair,
+                 const RowSetCollection::Elem row_indices,
+                 const GHistIndexMatrix& gmat,
+                 GHistRow hist);
+  // same, with feature grouping
  void BuildBlockHist(const std::vector<GradientPair>& gpair,
-                      const RowSetCollection::Elem row_indices,
-                      const GHistIndexBlockMatrix& gmatb,
-                      GHistRow hist) {
-    constexpr int kUnroll = 8;  // loop unrolling factor
-    const int32_t nblock = gmatb.GetNumBlock();
-    const size_t nrows = row_indices.end - row_indices.begin;
-    const size_t rest = nrows % kUnroll;
-
-    #pragma omp parallel for
-    for (int32_t bid = 0; bid < nblock; ++bid) {
-      auto gmat = gmatb[bid];
-
-      for (size_t i = 0; i < nrows - rest; i += kUnroll) {
-        size_t rid[kUnroll];
-        size_t ibegin[kUnroll];
-        size_t iend[kUnroll];
-        GradientPair stat[kUnroll];
-        for (int k = 0; k < kUnroll; ++k) {
-          rid[k] = row_indices.begin[i + k];
-        }
-        for (int k = 0; k < kUnroll; ++k) {
-          ibegin[k] = gmat.row_ptr[rid[k]];
-          iend[k] = gmat.row_ptr[rid[k] + 1];
-        }
-        for (int k = 0; k < kUnroll; ++k) {
-          stat[k] = gpair[rid[k]];
-        }
-        for (int k = 0; k < kUnroll; ++k) {
-          for (size_t j = ibegin[k]; j < iend[k]; ++j) {
-            const uint32_t bin = gmat.index[j];
-            hist[bin].Add(stat[k]);
-          }
-        }
-      }
-      for (size_t i = nrows - rest; i < nrows; ++i) {
-        const size_t rid = row_indices.begin[i];
-        const size_t ibegin = gmat.row_ptr[rid];
-        const size_t iend = gmat.row_ptr[rid + 1];
-        const GradientPair stat = gpair[rid];
-        for (size_t j = ibegin; j < iend; ++j) {
-          const uint32_t bin = gmat.index[j];
-          hist[bin].Add(stat);
-        }
-      }
-    }
-  }
+                      const RowSetCollection::Elem row_indices,
+                      const GHistIndexBlockMatrix& gmatb,
+                      GHistRow hist);
+  // construct a histogram via subtraction trick
+  void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent);
 
  uint32_t GetNumBins() {
    return nbins_;
@@ -546,19 +430,11 @@ class GHistBuilder {
  size_t nthread_;
  /*! \brief number of all bins over all features */
  uint32_t nbins_;
+  std::vector<size_t> thread_init_;
+  std::vector<tree::GradStats> data_;
 };
 
-void BuildHistLocalDense(size_t istart, size_t iend, size_t nrows, const size_t* rid,
-    const uint32_t* index, const GradientPair::ValueT* pgh, const size_t* row_ptr,
-    GradStatHist::GradType* data_local_hist, GradStatHist* grad_stat);
-
-void BuildHistLocalSparse(size_t istart, size_t iend, size_t nrows, const size_t* rid,
-    const uint32_t* index, const GradientPair::ValueT* pgh, const size_t* row_ptr,
-    GradStatHist::GradType* data_local_hist, GradStatHist* grad_stat);
-
-void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent);
-
 }  // namespace common
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_HIST_UTIL_H_
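
SubtractionTrick, declared above and defined in hist_util.cc, relies on the identity hist(parent) = hist(left) + hist(right): once the parent and one child have been built, the sibling's histogram is a cheap elementwise subtraction rather than a second pass over the rows. A minimal sketch of the idea, with GradHess standing in for tree::GradStats:

    // Sketch of the subtraction trick: sibling = parent - child, elementwise.
    #include <cassert>
    #include <cstddef>
    #include <vector>

    struct GradHess { double grad = 0.0, hess = 0.0; };

    void SubtractionTrick(std::vector<GradHess>* sibling,
                          const std::vector<GradHess>& child,
                          const std::vector<GradHess>& parent) {
      assert(child.size() == parent.size());
      sibling->resize(parent.size());
      for (std::size_t i = 0; i < parent.size(); ++i) {
        // every row of the parent lands in exactly one of the two children,
        // so the parent's bin totals are the sums of the children's
        (*sibling)[i].grad = parent[i].grad - child[i].grad;
        (*sibling)[i].hess = parent[i].hess - child[i].hess;
      }
    }
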
diff --git a/src/common/row_set.h b/src/common/row_set.h
index 39ae404f8779..285988b159c3 100644
--- a/src/common/row_set.h
+++ b/src/common/row_set.h
@@ -27,10 +27,10 @@ class RowSetCollection {
     // id of node associated with this instance set; -1 means uninitialized
     Elem() = default;
-    Elem(const size_t* begin_,
-         const size_t* end_,
-         int node_id_)
-        : begin(begin_), end(end_), node_id(node_id_) {}
+    Elem(const size_t* begin,
+         const size_t* end,
+         int node_id)
+        : begin(begin), end(end), node_id(node_id) {}
 
     inline size_t Size() const {
       return end - begin;
@@ -42,10 +42,6 @@ class RowSetCollection {
     std::vector<size_t> right;
   };
 
-  size_t Size(unsigned node_id) {
-    return elem_of_each_node_[node_id].Size();
-  }
-
   inline std::vector<Elem>::const_iterator begin() const {  // NOLINT
     return elem_of_each_node_.begin();
   }
@@ -55,12 +51,12 @@ class RowSetCollection {
   }
 
   /*! \brief return corresponding element set given the node_id */
-  inline Elem operator[](unsigned node_id) const {
-    const Elem e = elem_of_each_node_[node_id];
+  inline const Elem& operator[](unsigned node_id) const {
+    const Elem& e = elem_of_each_node_[node_id];
+    CHECK(e.begin != nullptr)
+        << "access element that is not in the set";
     return e;
   }
-
-  // clear up things
   inline void Clear() {
     elem_of_each_node_.clear();
   }
@@ -85,29 +81,38 @@ class RowSetCollection {
     const size_t* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
     elem_of_each_node_.emplace_back(Elem(begin, end, 0));
   }
-
   // split rowset into two
   inline void AddSplit(unsigned node_id,
-                       size_t iLeft,
+                       const std::vector<Split>& row_split_tloc,
                        unsigned left_node_id,
                        unsigned right_node_id) {
-    Elem e = elem_of_each_node_[node_id];
-
+    const Elem e = elem_of_each_node_[node_id];
+    const auto nthread = static_cast<bst_omp_uint>(row_split_tloc.size());
     CHECK(e.begin != nullptr);
+    size_t* all_begin = dmlc::BeginPtr(row_indices_);
+    size_t* begin = all_begin + (e.begin - all_begin);
 
-    size_t* begin = const_cast<size_t*>(e.begin);
-    size_t* split_pt = begin + iLeft;
+    size_t* it = begin;
+    for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
+      std::copy(row_split_tloc[tid].left.begin(), row_split_tloc[tid].left.end(), it);
+      it += row_split_tloc[tid].left.size();
+    }
+    size_t* split_pt = it;
+    for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
+      std::copy(row_split_tloc[tid].right.begin(), row_split_tloc[tid].right.end(), it);
+      it += row_split_tloc[tid].right.size();
+    }
 
     if (left_node_id >= elem_of_each_node_.size()) {
-      elem_of_each_node_.resize((left_node_id + 1)*2, Elem(nullptr, nullptr, -1));
+      elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr, -1));
     }
     if (right_node_id >= elem_of_each_node_.size()) {
-      elem_of_each_node_.resize((right_node_id + 1)*2, Elem(nullptr, nullptr, -1));
+      elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr, -1));
     }
 
     elem_of_each_node_[left_node_id] = Elem(begin, split_pt, left_node_id);
     elem_of_each_node_[right_node_id] = Elem(split_pt, e.end, right_node_id);
-    elem_of_each_node_[node_id] = Elem(begin, e.end, -1);
+    elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1);
   }
 
   // stores the row indices in the set
diff --git a/src/tree/param.h b/src/tree/param.h
index 0ca3ce472aef..2cebb3eec781 100644
--- a/src/tree/param.h
+++ b/src/tree/param.h
@@ -281,7 +281,7 @@ XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess
     }
   } else {
     T w = CalcWeight(p, sum_grad, sum_hess);
-    T ret = CalcGainGivenWeight(p, sum_grad, sum_hess, w);
+    T ret = CalcGainGivenWeight<TrainingParams, T>(p, sum_grad, sum_hess, w);
     if (p.reg_alpha == 0.0f) {
       return ret;
     } else {
@@ -301,7 +301,7 @@ template <typename TrainingParams, typename T>
 XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess,
                                  T test_grad, T test_hess) {
   T w = CalcWeight(sum_grad, sum_hess);
-  T ret = CalcGainGivenWeight(p, test_grad, test_hess);
+  T ret = CalcGainGivenWeight<TrainingParams, T>(p, test_grad, test_hess);
   if (p.reg_alpha == 0.0f) {
     return ret;
   } else {
@@ -340,16 +340,15 @@ XGBOOST_DEVICE inline float CalcWeight(const TrainingParams &p, GpairT sum_grad)
 }
 
 /*! \brief core statistics used for tree construction */
-struct GradStats {
-  typedef double GradType;
+struct XGBOOST_ALIGNAS(16) GradStats {
   /*! \brief sum gradient statistics */
-  GradType sum_grad;
+  double sum_grad;
   /*! \brief sum hessian statistics */
-  GradType sum_hess;
+  double sum_hess;
 
 public:
-  XGBOOST_DEVICE GradType GetGrad() const { return sum_grad; }
-  XGBOOST_DEVICE GradType GetHess() const { return sum_hess; }
+  XGBOOST_DEVICE double GetGrad() const { return sum_grad; }
+  XGBOOST_DEVICE double GetHess() const { return sum_hess; }
 
   XGBOOST_DEVICE GradStats() : sum_grad{0}, sum_hess{0} {
     static_assert(sizeof(GradStats) == 16,
@@ -359,7 +358,7 @@ struct GradStats {
   template <typename GpairT>
   XGBOOST_DEVICE explicit GradStats(const GpairT &sum)
       : sum_grad(sum.GetGrad()), sum_hess(sum.GetHess()) {}
-  explicit GradStats(const GradType grad, const GradType hess)
+  explicit GradStats(const double grad, const double hess)
       : sum_grad(grad), sum_hess(hess) {}
   /*!
    * \brief accumulate statistics
@@ -384,7 +383,7 @@ struct GradStats {
   /*! \return whether the statistics is not used yet */
   inline bool Empty() const { return sum_hess == 0.0; }
   /*! \brief add statistics to the data */
-  inline void Add(GradType grad, GradType hess) {
+  inline void Add(double grad, double hess) {
     sum_grad += grad;
     sum_hess += hess;
   }
@@ -402,7 +401,6 @@ struct SplitEntry {
   bst_float split_value{0.0f};
   GradStats left_sum;
   GradStats right_sum;
-  bool default_left{true};
 
   /*! \brief constructor */
   SplitEntry() = default;
@@ -417,11 +415,7 @@ struct SplitEntry {
   * \param split_index the feature index where the split is on
   */
  inline bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
-    if (!std::isfinite(new_loss_chg)) {  // in some cases new_loss_chg can be NaN or Inf,
-                                         // for example when lambda = 0 & min_child_weight = 0
-                                         // skip value in this case
-      return false;
-    } else if (this->SplitIndex() <= split_index) {
+    if (this->SplitIndex() <= split_index) {
      return new_loss_chg > this->loss_chg;
    } else {
      return !(this->loss_chg > new_loss_chg);
@@ -439,7 +433,6 @@ struct SplitEntry {
      this->split_value = e.split_value;
      this->left_sum = e.left_sum;
      this->right_sum = e.right_sum;
-      this->default_left = e.default_left;
      return true;
    } else {
      return false;
@@ -454,11 +447,13 @@ struct SplitEntry {
   * \return whether the proposed split is better and can replace current split
   */
  inline bool Update(bst_float new_loss_chg, unsigned split_index,
-                     bst_float new_split_value, bool new_default_left,
+                     bst_float new_split_value, bool default_left,
                      const GradStats &left_sum, const GradStats &right_sum) {
    if (this->NeedReplace(new_loss_chg, split_index)) {
      this->loss_chg = new_loss_chg;
-      this->default_left = new_default_left;
+      if (default_left) {
+        split_index |= (1U << 31);
+      }
      this->sindex = split_index;
      this->split_value = new_split_value;
      this->left_sum = left_sum;
@@ -474,9 +469,9 @@ struct SplitEntry {
    dst.Update(src);
  }
  /*!\return feature index to split on */
-  inline unsigned SplitIndex() const { return sindex; }
+  inline unsigned SplitIndex() const { return sindex & ((1U << 31) - 1U); }
  /*!\return whether missing value goes to left branch */
-  inline bool DefaultLeft() const { return default_left; }
+  inline bool DefaultLeft() const { return (sindex >> 31) != 0; }
 };
 
 }  // namespace tree
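
The change above stores the default direction in the top bit of sindex again, rather than in a separate default_left member, so SplitEntry stays compact and the flag travels with the feature index. A small self-contained sketch of that packing scheme (function names are illustrative):

    // Pack a 31-bit feature index and a default-left flag into one uint32_t,
    // mirroring SplitEntry::sindex above.
    #include <cassert>
    #include <cstdint>
    #include <iostream>

    inline uint32_t PackSplit(uint32_t split_index, bool default_left) {
      assert(split_index < (1U << 31));  // index must fit in 31 bits
      return default_left ? (split_index | (1U << 31)) : split_index;
    }

    inline uint32_t SplitIndex(uint32_t sindex) { return sindex & ((1U << 31) - 1U); }
    inline bool DefaultLeft(uint32_t sindex) { return (sindex >> 31) != 0; }

    int main() {
      const uint32_t s = PackSplit(42, true);
      std::cout << SplitIndex(s) << " " << DefaultLeft(s) << "\n";  // prints "42 1"
      return 0;
    }
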
diff --git a/src/tree/split_evaluator.cc b/src/tree/split_evaluator.cc
index a17127774f2f..ca3aeda71154 100644
--- a/src/tree/split_evaluator.cc
+++ b/src/tree/split_evaluator.cc
@@ -284,9 +284,7 @@ class MonotonicConstraint final : public SplitEvaluator {
       bst_float leftweight,
       bst_float rightweight) override {
     inner_->AddSplit(nodeid, leftid, rightid, featureid, leftweight, rightweight);
-
-    bst_uint newsize = std::max(bst_uint(lower_.size()), bst_uint(std::max(leftid, rightid) + 1u));
-
+    bst_uint newsize = std::max(leftid, rightid) + 1;
     lower_.resize(newsize);
     upper_.resize(newsize);
     bst_int constraint = GetConstraint(featureid);
diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc
index 71b9a3274307..afdf15cb9f9d 100644
--- a/src/tree/updater_quantile_hist.cc
+++ b/src/tree/updater_quantile_hist.cc
@@ -1,8 +1,8 @@
 /*!
- * Copyright 2017-2019 by Contributors
+ * Copyright 2017-2018 by Contributors
 * \file updater_quantile_hist.cc
 * \brief use quantized feature values to construct a tree
- * \author Philip Cho, Tianqi Checn, Egor Smirnov
+ * \author Philip Cho, Tianqi Checn
 */
 #include
 #include
@@ -41,7 +41,7 @@ void QuantileHistMaker::Configure(const Args& args) {
   param_.UpdateAllowUnknown(args);
   is_gmat_initialized_ = false;
 
-  // initialize the split evaluator
+  // initialise the split evaluator
   if (!spliteval_) {
     spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator));
   }
@@ -52,7 +52,6 @@ void QuantileHistMaker::Configure(const Args& args) {
 void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
                                DMatrix *dmat,
                                const std::vector<RegTree*> &trees) {
-  // omp_set_nested(1);
   if (is_gmat_initialized_ == false) {
     double tstart = dmlc::GetTime();
     gmat_.Init(dmat, static_cast<uint32_t>(param_.max_bin));
@@ -89,231 +88,94 @@ bool QuantileHistMaker::UpdatePredictionCache(
   }
 }
 
-void QuantileHistMaker::Builder::BuildNodeStat(
-    const GHistIndexMatrix &gmat,
-    DMatrix *p_fmat,
-    RegTree *p_tree,
-    const std::vector<GradientPair> &gpair_h,
-    int32_t nid) {
-
-  // add constraints
-  if (!(*p_tree)[nid].IsLeftChild() && !(*p_tree)[nid].IsRoot()) {
-    auto parent_id = (*p_tree)[nid].Parent();
-    // it's a right child
-    auto left_sibling_id = (*p_tree)[parent_id].LeftChild();
-    auto parent_split_feature_id = snode_[parent_id].best.SplitIndex();
-
-    spliteval_->AddSplit(parent_id, left_sibling_id, nid, parent_split_feature_id,
-                         snode_[left_sibling_id].weight, snode_[nid].weight);
-  }
+void QuantileHistMaker::Builder::SyncHistograms(
+    int starting_index,
+    int sync_count,
+    RegTree *p_tree) {
+  builder_monitor_.Start("SyncHistograms");
+  this->histred_.Allreduce(hist_[starting_index].data(), hist_builder_.GetNumBins() * sync_count);
+  // use Subtraction Trick
+  for (auto const& node_pair : nodes_for_subtraction_trick_) {
+    hist_.AddHistRow(node_pair.first);
+    SubtractionTrick(hist_[node_pair.first], hist_[node_pair.second],
+                     hist_[(*p_tree)[node_pair.first].Parent()]);
+  }
+  builder_monitor_.Stop("SyncHistograms");
 }
 
-void QuantileHistMaker::Builder::BuildNodeStatBatch(
+void QuantileHistMaker::Builder::BuildLocalHistograms(
+    int *starting_index,
+    int *sync_count,
     const GHistIndexMatrix &gmat,
-    DMatrix *p_fmat,
+    const GHistIndexBlockMatrix &gmatb,
     RegTree *p_tree,
-    const std::vector<GradientPair> &gpair_h,
-    const std::vector<ExpandEntry>& nodes) {
-  perf_monitor.TickStart();
-  for (const auto& node : nodes) {
-    const int32_t nid = node.nid;
-    const int32_t sibling_nid = node.sibling_nid;
-    this->InitNewNode(nid, gmat, gpair_h, *p_fmat, p_tree, &(snode_[nid]), (*p_tree)[nid].Parent());
-    if (sibling_nid > -1) {
-      this->InitNewNode(nid, gmat, gpair_h, *p_fmat, p_tree,
-                        &(snode_[sibling_nid]), (*p_tree)[sibling_nid].Parent());
-    }
-  }
-  for (const auto& node : nodes) {
-    const int32_t nid = node.nid;
-    const int32_t sibling_nid = node.sibling_nid;
-    BuildNodeStat(gmat, p_fmat, p_tree, gpair_h, nid);
-    if (sibling_nid > -1) {
-      BuildNodeStat(gmat, p_fmat, p_tree, gpair_h, sibling_nid);
-    }
-  }
-  perf_monitor.UpdatePerfTimer(TreeGrowingPerfMonitor::timer_name::INIT_NEW_NODE);
-}
-
-template<typename RowIdxType, typename IdxType>
-inline std::pair<size_t, size_t> PartitionDenseLeftDefaultKernel(const RowIdxType* rid,
-    const IdxType* idx, const IdxType offset, const int32_t split_cond,
-    const size_t istart, const size_t iend, RowIdxType* p_left, RowIdxType* p_right) {
-  size_t ileft = 0;
-  size_t iright = 0;
-
-  const IdxType max_val = std::numeric_limits<IdxType>::max();
-
-  for (size_t i = istart; i < iend; i++) {
-    if (idx[rid[i]] == max_val || static_cast<int32_t>(idx[rid[i]] + offset) <= split_cond) {
-      p_left[ileft++] = rid[i];
-    } else {
-      p_right[iright++] = rid[i];
-    }
-  }
-
-  return { ileft, iright };
-}
-
-template<typename RowIdxType, typename IdxType>
-inline std::pair<size_t, size_t> PartitionDenseRightDefaultKernel(const RowIdxType* rid,
-    const IdxType* idx, const IdxType offset, const int32_t split_cond,
-    const size_t istart, const size_t iend, RowIdxType* p_left, RowIdxType* p_right) {
-  size_t ileft = 0;
-  size_t iright = 0;
-
-  const IdxType max_val = std::numeric_limits<IdxType>::max();
-
-  for (size_t i = istart; i < iend; i++) {
-    if (idx[rid[i]] == max_val || static_cast<int32_t>(idx[rid[i]] + offset) > split_cond) {
-      p_right[iright++] = rid[i];
-    } else {
-      p_left[ileft++] = rid[i];
-    }
-  }
-  return { ileft, iright };
-}
-
-template<typename RowIdxType, typename IdxType>
-inline std::pair<size_t, size_t> PartitionSparseKernel(const RowIdxType* rowid,
-    const IdxType* idx, const int32_t split_cond, const size_t ibegin,
-    const size_t iend, RowIdxType* p_left, RowIdxType* p_right,
-    Column column, bool default_left) {
-  size_t ileft = 0;
-  size_t iright = 0;
-
-  if (ibegin < iend) {  // ensure that [ibegin, iend) is nonempty range
-    // search first nonzero row with index >= rowid[ibegin]
-    const size_t* p = std::lower_bound(column.GetRowData(),
-                                       column.GetRowData() + column.Size(),
-                                       rowid[ibegin]);
-    if (p != column.GetRowData() + column.Size() && *p <= rowid[iend - 1]) {
-      size_t cursor = p - column.GetRowData();
-
-      for (size_t i = ibegin; i < iend; ++i) {
-        const size_t rid = rowid[i];
-        while (cursor < column.Size()
-               && column.GetRowIdx(cursor) < rid
-               && column.GetRowIdx(cursor) <= rowid[iend - 1]) {
-          ++cursor;
-        }
-        if (cursor < column.Size() && column.GetRowIdx(cursor) == rid) {
-          const uint32_t rbin = column.GetFeatureBinIdx(cursor);
-          if (static_cast<int32_t>(rbin + column.GetBaseIdx()) <= split_cond) {
-            p_left[ileft++] = rid;
-          } else {
-            p_right[iright++] = rid;
-          }
-          ++cursor;
-        } else {
-          // missing value
-          if (default_left) {
-            p_left[ileft++] = rid;
-          } else {
-            p_right[iright++] = rid;
-          }
+    const std::vector<GradientPair> &gpair_h) {
+  builder_monitor_.Start("BuildLocalHistograms");
+  for (auto const& entry : qexpand_depth_wise_) {
+    int nid = entry.nid;
+    RegTree::Node &node = (*p_tree)[nid];
+    if (rabit::IsDistributed()) {
+      if (node.IsRoot() || node.IsLeftChild()) {
+        hist_.AddHistRow(nid);
+        // in distributed setting, we always calculate from left child or root node
+        BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], false);
+        if (!node.IsRoot()) {
+          nodes_for_subtraction_trick_[(*p_tree)[node.Parent()].RightChild()] = nid;
         }
+        (*sync_count)++;
+        (*starting_index) = std::min((*starting_index), nid);
       }
-    } else {  // all rows in [ibegin, iend) have missing values
-      if (default_left) {
-        for (size_t i = ibegin; i < iend; ++i) {
-          const size_t rid = rowid[i];
-          p_left[ileft++] = rid;
-        }
-      } else {
-        for (size_t i = ibegin; i < iend; ++i) {
-          const size_t rid = rowid[i];
-          p_right[iright++] = rid;
-        }
+    } else {
+      if (!node.IsRoot() && node.IsLeftChild() &&
+          (row_set_collection_[nid].Size() <
+           row_set_collection_[(*p_tree)[node.Parent()].RightChild()].Size())) {
+        hist_.AddHistRow(nid);
+        BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], false);
+        nodes_for_subtraction_trick_[(*p_tree)[node.Parent()].RightChild()] = nid;
+        (*sync_count)++;
+        (*starting_index) = std::min((*starting_index), nid);
+      } else if (!node.IsRoot() && !node.IsLeftChild() &&
+                 (row_set_collection_[nid].Size() <=
+                  row_set_collection_[(*p_tree)[node.Parent()].LeftChild()].Size())) {
+        hist_.AddHistRow(nid);
+        BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], false);
+        nodes_for_subtraction_trick_[(*p_tree)[node.Parent()].LeftChild()] = nid;
+        (*sync_count)++;
+        (*starting_index) = std::min((*starting_index), nid);
+      } else if (node.IsRoot()) {
+        hist_.AddHistRow(nid);
+        BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], false);
+        (*sync_count)++;
+        (*starting_index) = std::min((*starting_index), nid);
       }
     }
   }
-  return {ileft, iright};
-}
-
-int32_t QuantileHistMaker::Builder::FindSplitCond(int32_t nid,
-                                                  RegTree *p_tree,
-                                                  const GHistIndexMatrix &gmat) {
-  bst_float left_leaf_weight = spliteval_->ComputeWeight(nid,
-      snode_[nid].best.left_sum) * param_.learning_rate;
-  bst_float right_leaf_weight = spliteval_->ComputeWeight(nid,
-      snode_[nid].best.right_sum) * param_.learning_rate;
-  p_tree->ExpandNode(nid, snode_[nid].best.SplitIndex(), snode_[nid].best.split_value,
-                     snode_[nid].best.DefaultLeft(), snode_[nid].weight, left_leaf_weight,
-                     right_leaf_weight, snode_[nid].best.loss_chg, snode_[nid].stats.sum_hess);
-
-  RegTree::Node node = (*p_tree)[nid];
-  // Categorize member rows
-  const bst_uint fid = node.SplitIndex();
-  const bst_float split_pt = node.SplitCond();
-  const uint32_t lower_bound = gmat.cut.Ptrs()[fid];
-  const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1];
-  int32_t split_cond = -1;
-  // convert floating-point split_pt into corresponding bin_id
-  // split_cond = -1 indicates that split_pt is less than all known cut points
-  CHECK_LT(upper_bound,
-           static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
-  for (uint32_t i = lower_bound; i < upper_bound; ++i) {
-    if (split_pt == gmat.cut.Values()[i]) {
-      split_cond = static_cast<int32_t>(i);
-    }
-  }
-  return split_cond;
+  builder_monitor_.Stop("BuildLocalHistograms");
 }
 
-// split rows in each node to blocks of rows
-// for future parallel execution
-template<typename TaskType, typename NodeBoundsType>
-void QuantileHistMaker::Builder::CreateTasksForApplySplit(
-    const std::vector<ExpandEntry>& nodes,
-    const GHistIndexMatrix &gmat,
-    RegTree *p_tree,
-    int *num_leaves,
-    const int depth,
-    const size_t block_size,
-    std::vector<TaskType>* tasks,
-    std::vector<NodeBoundsType>* nodes_bounds) {
-  size_t* buffer = buffer_for_partition_.data();
-  size_t cur_buff_offset = 0;
-
-  auto create_nodes = [&](int32_t this_nid) {
-    if (snode_[this_nid].best.loss_chg < kRtEps ||
-        (param_.max_depth > 0 && depth == param_.max_depth) ||
-        (param_.max_leaves > 0 && (*num_leaves) == param_.max_leaves)) {
-      (*p_tree)[this_nid].SetLeaf(snode_[this_nid].weight * param_.learning_rate);
-    } else {
-      const size_t nrows = row_set_collection_[this_nid].Size();
-      const size_t n_blocks = nrows / block_size + !!(nrows % block_size);
-
-      nodes_bounds->emplace_back(this_nid, tasks->size(), tasks->size() + n_blocks);
-
-      const int32_t split_cond = FindSplitCond(this_nid, p_tree, gmat);
-
-      for (size_t i = 0; i < n_blocks; ++i) {
-        const size_t istart = i*block_size;
-        const size_t iend = (i == n_blocks-1) ? nrows : istart + block_size;
-
-        TaskType task {this_nid, split_cond, n_blocks, i, istart, iend, nodes_bounds->size()-1,
-                       buffer + cur_buff_offset, buffer + cur_buff_offset + (iend-istart), 0, 0, 0, 0};
-        tasks->push_back(task);
-        cur_buff_offset += 2*(iend-istart);
-      }
-    }
-  };
-  for (const auto& node : nodes) {
-    const int32_t nid = node.nid;
-    const int32_t sibling_nid = node.sibling_nid;
-    create_nodes(nid);
-
-    if (sibling_nid > -1) {
-      create_nodes(sibling_nid);
-    }
-  }
-}
-
+void QuantileHistMaker::Builder::BuildNodeStats(
+    const GHistIndexMatrix &gmat,
+    DMatrix *p_fmat,
+    RegTree *p_tree,
+    const std::vector<GradientPair> &gpair_h) {
+  builder_monitor_.Start("BuildNodeStats");
+  for (auto const& entry : qexpand_depth_wise_) {
+    int nid = entry.nid;
+    this->InitNewNode(nid, gmat, gpair_h, *p_fmat, *p_tree);
+    // add constraints
+    if (!(*p_tree)[nid].IsLeftChild() && !(*p_tree)[nid].IsRoot()) {
+      // it's a right child
+      auto parent_id = (*p_tree)[nid].Parent();
+      auto left_sibling_id = (*p_tree)[parent_id].LeftChild();
+      auto parent_split_feature_id = snode_[parent_id].best.SplitIndex();
+      spliteval_->AddSplit(parent_id, left_sibling_id, nid, parent_split_feature_id,
+                           snode_[left_sibling_id].weight, snode_[nid].weight);
+    }
+  }
+  builder_monitor_.Stop("BuildNodeStats");
 }
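
BuildLocalHistograms above applies a simple cost rule: scan rows only for the root or for the smaller of two siblings, and record the larger sibling in nodes_for_subtraction_trick_ so SyncHistograms can later derive it from the parent. A compact sketch of that decision, assuming per-node row counts are known (all names illustrative):

    #include <cstddef>

    // Decide which child of a split gets its histogram built explicitly.
    // The sibling is later reconstructed as parent - built_child, so the
    // cheaper (smaller) child is always the one whose rows are scanned.
    struct BuildPlan {
      int build_nid;     // node whose rows are scanned
      int subtract_nid;  // node derived by subtraction
    };

    inline BuildPlan PlanHistograms(int left_nid, std::size_t left_rows,
                                    int right_nid, std::size_t right_rows) {
      if (left_rows < right_rows) {
        return {left_nid, right_nid};
      }
      return {right_nid, left_nid};
    }
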
-void QuantileHistMaker::Builder::CreateNewNodesBatch(
-    const std::vector<ExpandEntry>& nodes,
+void QuantileHistMaker::Builder::EvaluateSplits(
     const GHistIndexMatrix &gmat,
     const ColumnMatrix &column_matrix,
     DMatrix *p_fmat,
@@ -322,367 +184,23 @@ void QuantileHistMaker::Builder::CreateNewNodesBatch(
     int depth,
     unsigned *timestamp,
     std::vector<ExpandEntry> *temp_qexpand_depth) {
-  perf_monitor.TickStart();
-  const size_t block_size = 2048;
-
-  struct ApplySplitTaskInfo {
-    // input
-    int32_t nid;
-    int32_t split_cond;
-    size_t n_blocks_this_node;
-    size_t i_block_this_node;
-    size_t istart;
-    size_t iend;
-    size_t inode;
-    // result
-    size_t* left;
-    size_t* right;
-    size_t n_left;
-    size_t n_right;
-    size_t ileft;
-    size_t iright;
-  };
-
-  struct NodeBoundsInfo {
-    NodeBoundsInfo(int32_t nid, size_t begin, size_t end):
-      nid(nid), begin(begin), end(end) {
-    }
-
-    int32_t nid;
-    size_t begin;
-    size_t end;
-  };
-
-  // create tasks for partition of row_set_collection_
-  std::vector<ApplySplitTaskInfo> tasks;
-  std::vector<NodeBoundsInfo> nodes_bounds;
-
-  // 1. Split row-indexes in each nodes to blocks of rows
-  CreateTasksForApplySplit(nodes, gmat, p_tree, num_leaves,
-                           depth, block_size, &tasks, &nodes_bounds);
-
-  // buffer to store # of rows in left part for each row-block
-  std::vector<size_t> left_sizes;
-  left_sizes.reserve(nodes_bounds.size());
-  const int size = tasks.size();
-
-  // execute tasks in parallel
-  #pragma omp parallel
-  {
-    // 2. For each block of rows:
-    //    a) Write row-indexes which should come to the left child - to 1th buffer
-    //    b) Write row-indexes which should come to the right child - to 2th buffer
-    //    values in each buffer - sorted in original order
-    #pragma omp for
-    for (int32_t i = 0; i < size; ++i) {
-      const int32_t nid = tasks[i].nid;
-      const int32_t split_cond = tasks[i].split_cond;
-      const size_t istart = tasks[i].istart;
-      const size_t iend = tasks[i].iend;
-
-      const bst_uint fid = (*p_tree)[nid].SplitIndex();
-      const bool default_left = (*p_tree)[nid].DefaultLeft();
-      const Column column = column_matrix.GetColumn(fid);
-
-      const uint32_t* idx = column.GetIndex();
-      const size_t* rid = row_set_collection_[nid].begin;
-
-      if (column.GetType() == xgboost::common::kDenseColumn) {
-        if (default_left) {
-          auto res = PartitionDenseLeftDefaultKernel<size_t, uint32_t>(
-              rid, idx, column.GetBaseIdx(), split_cond, istart, iend,
-              tasks[i].left, tasks[i].right);
-          tasks[i].n_left = res.first;
-          tasks[i].n_right = res.second;
-        } else {
-          auto res = PartitionDenseRightDefaultKernel<size_t, uint32_t>(
-              rid, idx, column.GetBaseIdx(), split_cond, istart, iend,
-              tasks[i].left, tasks[i].right);
-          tasks[i].n_left = res.first;
-          tasks[i].n_right = res.second;
-        }
-      } else {
-        auto res = PartitionSparseKernel<size_t, uint32_t>(
-            rid, idx, split_cond, istart, iend, tasks[i].left, tasks[i].right, column, default_left);
-        tasks[i].n_left = res.first;
-        tasks[i].n_right = res.second;
-      }
-    }
-
-    // 3. For each node - find number of elements in left the part
-    #pragma omp single
-    {
-      for (auto& node : nodes_bounds) {
-        size_t n_left = 0;
-        size_t n_right = 0;
-
-        for (size_t i = node.begin; i < node.end; ++i) {
-          tasks[i].ileft = n_left;
-          tasks[i].iright = n_right;
-
-          n_left += tasks[i].n_left;
-          n_right += tasks[i].n_right;
-        }
-        left_sizes.push_back(n_left);
-      }
-    }
-
-    // 4. Copy data from buffers to original row_set_collection_
-    #pragma omp for
-    for (int32_t i = 0; i < size; ++i) {
-      const size_t node_idx = tasks[i].inode;
-      const int32_t nid = tasks[i].nid;
-      const size_t n_left = left_sizes[node_idx];
-
-      CHECK_LE(tasks[i].ileft + tasks[i].n_left, row_set_collection_[nid].Size());
-      CHECK_LE(n_left + tasks[i].iright + tasks[i].n_right, row_set_collection_[nid].Size());
-
-      auto* rid = const_cast<size_t*>(row_set_collection_[nid].begin);
-      std::memcpy(rid + tasks[i].ileft, tasks[i].left,
-                  tasks[i].n_left * sizeof(rid[0]));
-      std::memcpy(rid + n_left + tasks[i].iright, tasks[i].right,
-                  tasks[i].n_right * sizeof(rid[0]));
-    }
-  }
-
-  // register new nodes
-  for (size_t i = 0; i < nodes_bounds.size(); ++i) {
-    const int32_t nid = nodes_bounds[i].nid;
-    const size_t n_left = left_sizes[i];
-    RegTree::Node node = (*p_tree)[nid];
-
-    const int32_t left_id = node.LeftChild();
-    const int32_t right_id = node.RightChild();
-    row_set_collection_.AddSplit(nid, n_left, left_id, right_id);
-
-    if (rabit::IsDistributed() ||
-        row_set_collection_[left_id].Size() < row_set_collection_[right_id].Size()) {
-      temp_qexpand_depth->push_back(ExpandEntry(left_id, right_id, nid,
-            depth + 1, 0.0, (*timestamp)++));
-    } else {
-      temp_qexpand_depth->push_back(ExpandEntry(right_id, left_id, nid,
-            depth + 1, 0.0, (*timestamp)++));
-    }
-  }
-  perf_monitor.UpdatePerfTimer(TreeGrowingPerfMonitor::timer_name::APPLY_SPLIT);
-}
-
-std::tuple<common::GradStatHist::GradType*, common::GradStatHist*>
-  QuantileHistMaker::Builder::GetHistBuffer(
-    std::vector<uint8_t>* hist_is_init, std::vector<common::GradStatHist>* grad_stats,
-    size_t block_id, size_t nthread, size_t tid,
-    std::vector<common::GradStatHist::GradType*>* data_hist, size_t hist_size) {
-
-  const size_t n_hist_for_current_node = hist_is_init->size();
-  const size_t hist_id = ((n_hist_for_current_node == nthread) ? tid : block_id);
-
-  common::GradStatHist::GradType* local_data_hist = (*data_hist)[hist_id];
-  if (!(*hist_is_init)[hist_id]) {
-    std::fill(local_data_hist, local_data_hist + hist_size, 0.0f);
-    (*hist_is_init)[hist_id] = true;
-  }
-
-  return std::make_tuple(local_data_hist, &(*grad_stats)[hist_id]);
-}
-
-void QuantileHistMaker::Builder::CreateTasksForBuildHist(
-    size_t block_size_rows,
-    size_t nthread,
-    const std::vector<ExpandEntry>& nodes,
-    std::vector<std::vector<common::GradStatHist::GradType*>>* hist_buffers,
-    std::vector<std::vector<uint8_t>>* hist_is_init,
-    std::vector<std::vector<common::GradStatHist>>* grad_stats,
-    std::vector<int32_t>* task_nid,
-    std::vector<int32_t>* task_node_idx,
-    std::vector<int32_t>* task_block_idx) {
-  size_t i_hist = 0;
-
-  // prepare tasks for parallel execution
-  for (size_t i = 0; i < nodes.size(); ++i) {
-    const int32_t nid = nodes[i].nid;
-    const int32_t sibling_nid = nodes[i].sibling_nid;
-    hist_.AddHistRow(nid);
-    if (sibling_nid > -1) {
-      hist_.AddHistRow(sibling_nid);
-    }
-    const size_t nrows = row_set_collection_[nid].Size();
-    const size_t n_local_blocks = nrows / block_size_rows + !!(nrows % block_size_rows);
-    const size_t n_local_histograms = std::min(nthread, n_local_blocks);
-
-    task_nid->resize(task_nid->size() + n_local_blocks, nid);
-    for (size_t j = 0; j < n_local_blocks; ++j) {
-      task_node_idx->push_back(i);
-      task_block_idx->push_back(j);
-    }
-
-    (*hist_buffers)[i].clear();
-    for (size_t j = 0; j < n_local_histograms; j++) {
-      (*hist_buffers)[i].push_back(
-        reinterpret_cast<common::GradStatHist::GradType*>(hist_buff_[i_hist++].data()));
-    }
-    (*hist_is_init)[i].clear();
-    (*hist_is_init)[i].resize(n_local_histograms, false);
-    (*grad_stats)[i].resize(n_local_histograms);
-  }
-}
-
-void QuantileHistMaker::Builder::BuildHistsBatch(const std::vector<ExpandEntry>& nodes,
-    RegTree* p_tree, const GHistIndexMatrix &gmat, const std::vector<GradientPair>& gpair,
-    std::vector<std::vector<common::GradStatHist::GradType*>>* hist_buffers,
-    std::vector<std::vector<uint8_t>>* hist_is_init) {
-  perf_monitor.TickStart();
-  const size_t block_size_rows = 256;
-  const size_t nthread = static_cast<size_t>(this->nthread_);
-  const size_t nbins = gmat.cut.Ptrs().back();
-  const size_t hist_size = 2 * nbins;
-
-  hist_buffers->resize(nodes.size());
-  hist_is_init->resize(nodes.size());
-
-  // input data for tasks
-  std::vector<int32_t> task_nid;
-  std::vector<int32_t> task_node_idx;
-  std::vector<int32_t> task_block_idx;
-
-  // result vector
-  std::vector<std::vector<common::GradStatHist>> grad_stats(nodes.size());
-
-  // 1. Create tasks for hist construction by block of rows for each node
-  CreateTasksForBuildHist(block_size_rows, nthread, nodes, hist_buffers, hist_is_init, &grad_stats,
-                          &task_nid, &task_node_idx, &task_block_idx);
-  int32_t n_hist_buidling_tasks = task_node_idx.size();
-
-  const GradientPair::ValueT* const pgh =
-      reinterpret_cast<const GradientPair::ValueT*>(gpair.data());
-
-  // 2. Build partial histograms for each node
-  #pragma omp parallel for schedule(static)
-  for (int32_t itask = 0; itask < n_hist_buidling_tasks; ++itask) {
-    const size_t tid = omp_get_thread_num();
-    const int32_t nid = task_nid[itask];
-    const int32_t block_id = task_block_idx[itask];
-    // node_idx : location of node `nid` within the `nodes` list. In general, node_idx != nid
-    const int32_t node_idx = task_node_idx[itask];
-
-    common::GradStatHist::GradType* data_local_hist;
-    common::GradStatHist* grad_stat;  // total gradient/hessian value for node `nid`
-    std::tie(data_local_hist, grad_stat) = GetHistBuffer(&(*hist_is_init)[node_idx],
-        &grad_stats[node_idx], block_id, nthread, tid,
-        &(*hist_buffers)[node_idx], hist_size);
-
-    const size_t* row_ptr = gmat.row_ptr.data();
-    const size_t* rid = row_set_collection_[nid].begin;
-
-    const size_t nrows = row_set_collection_[nid].Size();
-    const size_t istart = block_id * block_size_rows;
-    const size_t iend = (((block_id+1)*block_size_rows > nrows) ? nrows : istart + block_size_rows);
-
-    // call hist building kernel depending on bin-matrix layout
-    if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
-      common::BuildHistLocalDense(istart, iend, nrows, rid, gmat.index.data(), pgh,
-                                  row_ptr, data_local_hist, grad_stat);
-    } else {
-      common::BuildHistLocalSparse(istart, iend, nrows, rid, gmat.index.data(), pgh,
-                                   row_ptr, data_local_hist, grad_stat);
-    }
-  }
-
-  // 3. Merge grad stats for each node
-  //    Sync histograms in case of distributed computation
-  SyncHistograms(p_tree, nodes, hist_buffers, hist_is_init, grad_stats);
-
-  perf_monitor.UpdatePerfTimer(TreeGrowingPerfMonitor::timer_name::BUILD_HIST);
-}
-
-void QuantileHistMaker::Builder::SyncHistograms(
-    RegTree* p_tree,
-    const std::vector<ExpandEntry>& nodes,
-    std::vector<std::vector<common::GradStatHist::GradType*>>* hist_buffers,
-    std::vector<std::vector<uint8_t>>* hist_is_init,
-    const std::vector<std::vector<common::GradStatHist>>& grad_stats) {
-  if (rabit::IsDistributed()) {
-    const int size = nodes.size();
-    #pragma omp parallel for  // TODO(egorsmir): replace to n_features * nodes.size()
-    for (int i = 0; i < size; ++i) {
-      const int32_t nid = nodes[i].nid;
-      common::GradStatHist::GradType* hist_data =
-          reinterpret_cast<common::GradStatHist::GradType*>(hist_[nid].data());
-
-      ReduceHistograms(hist_data, nullptr, nullptr, 0, hist_builder_.GetNumBins() * 2, i,
-                       *hist_is_init, *hist_buffers);
-    }
-
-    for (auto elem : nodes) {
-      this->histred_.Allreduce(hist_[elem.nid].data(), hist_builder_.GetNumBins());
-    }
-
-    // TODO(egorsmir): add parallel for
-    for (auto elem : nodes) {
-      if (elem.sibling_nid > -1) {
-        SubtractionTrick(hist_[elem.sibling_nid], hist_[elem.nid],
-                         hist_[(*p_tree)[elem.sibling_nid].Parent()]);
-      }
-    }
-  }
-
-  // merge grad stats
-  {
-    for (size_t inode = 0; inode < nodes.size(); ++inode) {
-      const int32_t nid = nodes[inode].nid;
-
-      if (snode_.size() <= size_t(nid)) {
-        snode_.resize(nid + 1, NodeEntry(param_));
-      }
-
-      common::GradStatHist grad_stat;
-      for (size_t ihist = 0; ihist < (*hist_is_init)[inode].size(); ++ihist) {
-        if ((*hist_is_init)[inode][ihist]) {
-          grad_stat.Add(grad_stats[inode][ihist]);
-        }
-      }
-      this->histred_.Allreduce(&grad_stat, 1);
-      snode_[nid].stats = grad_stat.ToGradStat();
-
-      const int32_t sibling_nid = nodes[inode].sibling_nid;
-      if (sibling_nid > -1) {
-        if (snode_.size() <= size_t(sibling_nid)) {
-          snode_.resize(sibling_nid + 1, NodeEntry(param_));
-        }
-        const int parent_id = (*p_tree)[nid].Parent();
-        snode_[sibling_nid].stats.SetSubstract(snode_[parent_id].stats, snode_[nid].stats);
-      }
-    }
-  }
-}
-
-// merge some block of partial histograms
-void QuantileHistMaker::Builder::ReduceHistograms(
-    common::GradStatHist::GradType* hist_data,
-    common::GradStatHist::GradType* sibling_hist_data,
-    common::GradStatHist::GradType* parent_hist_data,
-    const size_t ibegin,
-    const size_t iend,
-    const size_t inode,
-    const std::vector<std::vector<uint8_t>>& hist_is_init,
-    const std::vector<std::vector<common::GradStatHist::GradType*>>& hist_buffers) {
-  bool is_init = false;
-  for (size_t ihist = 0; ihist < hist_is_init[inode].size(); ++ihist) {
-    common::GradStatHist::GradType* partial_data = hist_buffers[inode][ihist];
-    if (hist_is_init[inode][ihist] && is_init) {
-      for (size_t i = ibegin; i < iend; ++i) {
-        hist_data[i] += partial_data[i];
-      }
-    } else if (hist_is_init[inode][ihist]) {
-      for (size_t i = ibegin; i < iend; ++i) {
-        hist_data[i] = partial_data[i];
-      }
-      is_init = true;
-    }
-  }
-
-  if (sibling_hist_data) {
-    for (size_t i = ibegin; i < iend; ++i) {
-      sibling_hist_data[i] = parent_hist_data[i] - hist_data[i];
-    }
-  }
-}
+  for (auto const& entry : qexpand_depth_wise_) {
+    int nid = entry.nid;
+    this->EvaluateSplit(nid, gmat, hist_, *p_fmat, *p_tree);
+    if (snode_[nid].best.loss_chg < kRtEps ||
+        (param_.max_depth > 0 && depth == param_.max_depth) ||
+        (param_.max_leaves > 0 && (*num_leaves) == param_.max_leaves)) {
+      (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate);
+    } else {
+      this->ApplySplit(nid, gmat, column_matrix, hist_, *p_fmat, p_tree);
+      int left_id = (*p_tree)[nid].LeftChild();
+      int right_id = (*p_tree)[nid].RightChild();
+      temp_qexpand_depth->push_back(ExpandEntry(left_id,
+          p_tree->GetDepth(left_id), 0.0, (*timestamp)++));
+      temp_qexpand_depth->push_back(ExpandEntry(right_id,
+          p_tree->GetDepth(right_id), 0.0, (*timestamp)++));
+      // - 1 parent + 2 new children
+      (*num_leaves)++;
    }
  }
 }
@@ -691,34 +209,24 @@ void QuantileHistMaker::Builder::ExpandWithDepthWise(
    const GHistIndexMatrix &gmat,
    const GHistIndexBlockMatrix &gmatb,
    const ColumnMatrix &column_matrix,
-    DMatrix* p_fmat,
-    RegTree* p_tree,
+    DMatrix *p_fmat,
+    RegTree *p_tree,
    const std::vector<GradientPair> &gpair_h) {
  unsigned timestamp = 0;
  int num_leaves = 0;
 
  // in depth_wise growing, we feed loss_chg with 0.0 since it is not used anyway
-  qexpand_depth_wise_.emplace_back(0, -1, ROOT_PARENT_ID, p_tree->GetDepth(0), 0.0, timestamp++);
+  qexpand_depth_wise_.emplace_back(ExpandEntry(0, p_tree->GetDepth(0), 0.0, timestamp++));
  ++num_leaves;
  for (int depth = 0; depth < param_.max_depth + 1; depth++) {
+    int starting_index = std::numeric_limits<int>::max();
+    int sync_count = 0;
    std::vector<ExpandEntry> temp_qexpand_depth;
-
-    // buffer to store partial histograms
-    std::vector<std::vector<common::GradStatHist::GradType*>> hist_buffers;
-    // uint8_t is used instead of bool due to read/write
-    // to std::vector<bool> - thread unsafe
-    std::vector<std::vector<uint8_t>> hist_is_init;
-
-    BuildHistsBatch(qexpand_depth_wise_, p_tree, gmat, gpair_h,
-                    &hist_buffers, &hist_is_init);
-    BuildNodeStatBatch(gmat, p_fmat, p_tree, gpair_h, qexpand_depth_wise_);
-    EvaluateSplitsBatch(qexpand_depth_wise_, gmat, *p_fmat, hist_is_init, hist_buffers);
-    CreateNewNodesBatch(qexpand_depth_wise_, gmat, column_matrix, p_fmat, p_tree,
-                        &num_leaves, depth, &timestamp, &temp_qexpand_depth);
-
-    num_leaves += temp_qexpand_depth.size();
-
+    BuildLocalHistograms(&starting_index, &sync_count, gmat, gmatb, p_tree, gpair_h);
+    SyncHistograms(starting_index, sync_count, p_tree);
+    BuildNodeStats(gmat, p_fmat, p_tree, gpair_h);
+    EvaluateSplits(gmat, column_matrix, p_fmat, p_tree, &num_leaves, depth, &timestamp,
+                   &temp_qexpand_depth);
    // clean up
    qexpand_depth_wise_.clear();
    nodes_for_subtraction_trick_.clear();
@@ -738,21 +246,18 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
    DMatrix* p_fmat,
    RegTree* p_tree,
    const std::vector<GradientPair>& gpair_h) {
+  unsigned timestamp = 0;
  int num_leaves = 0;
 
-  std::vector<std::vector<common::GradStatHist::GradType*>> hist_buffers;
-  std::vector<std::vector<uint8_t>> hist_is_init;
-
  for (int nid = 0; nid < p_tree->param.num_roots; ++nid) {
-    std::vector<ExpandEntry> nodes_to_build{ExpandEntry(
-        0, -1, ROOT_PARENT_ID, p_tree->GetDepth(0), 0.0, timestamp++)};
+    hist_.AddHistRow(nid);
+    BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], true);
 
-    BuildHistsBatch(nodes_to_build, p_tree, gmat, gpair_h, &hist_buffers, &hist_is_init);
-    BuildNodeStatBatch(gmat, p_fmat, p_tree, gpair_h, nodes_to_build);
-    EvaluateSplitsBatch(nodes_to_build, gmat, *p_fmat, hist_is_init, hist_buffers);
+    this->InitNewNode(nid, gmat, gpair_h, *p_fmat, *p_tree);
 
-    qexpand_loss_guided_->push(ExpandEntry(nid, -1, -1, p_tree->GetDepth(nid),
+    this->EvaluateSplit(nid, gmat, hist_, *p_fmat, *p_tree);
+    qexpand_loss_guided_->push(ExpandEntry(nid, p_tree->GetDepth(nid),
                                           snode_[nid].best.loss_chg,
                                           timestamp++));
    ++num_leaves;
@@ -760,29 +265,50 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
  while (!qexpand_loss_guided_->empty()) {
    const ExpandEntry candidate = qexpand_loss_guided_->top();
-    const int32_t nid = candidate.nid;
+    const int nid = candidate.nid;
    qexpand_loss_guided_->pop();
+    if (candidate.loss_chg <= kRtEps
+        || (param_.max_depth > 0 && candidate.depth == param_.max_depth)
+        || (param_.max_leaves > 0 && num_leaves == param_.max_leaves) ) {
+      (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate);
+    } else {
+      this->ApplySplit(nid, gmat, column_matrix, hist_, *p_fmat, p_tree);
 
-    std::vector<ExpandEntry> nodes_to_build{candidate};
-    std::vector<ExpandEntry> successors;
+      const int cleft = (*p_tree)[nid].LeftChild();
+      const int cright = (*p_tree)[nid].RightChild();
+      hist_.AddHistRow(cleft);
+      hist_.AddHistRow(cright);
 
-    CreateNewNodesBatch(nodes_to_build, gmat, column_matrix, p_fmat, p_tree,
-                        &num_leaves, candidate.depth, &timestamp, &successors);
+      if (rabit::IsDistributed()) {
+        // in distributed mode, we need to keep consistent across workers
+        BuildHist(gpair_h, row_set_collection_[cleft], gmat, gmatb, hist_[cleft], true);
+        SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]);
+      } else {
+        if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) {
+          BuildHist(gpair_h, row_set_collection_[cleft], gmat, gmatb, hist_[cleft], true);
+          SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]);
+        } else {
+          BuildHist(gpair_h, row_set_collection_[cright], gmat, gmatb, hist_[cright], true);
+          SubtractionTrick(hist_[cleft], hist_[cright], hist_[nid]);
+        }
+      }
 
-    if (!successors.empty()) {
-      BuildHistsBatch(successors, p_tree, gmat, gpair_h, &hist_buffers, &hist_is_init);
-      BuildNodeStatBatch(gmat, p_fmat, p_tree, gpair_h, successors);
-      EvaluateSplitsBatch(successors, gmat, *p_fmat, hist_is_init, hist_buffers);
+      this->InitNewNode(cleft, gmat, gpair_h, *p_fmat, *p_tree);
+      this->InitNewNode(cright, gmat, gpair_h, *p_fmat, *p_tree);
+      bst_uint featureid = snode_[nid].best.SplitIndex();
+      spliteval_->AddSplit(nid, cleft, cright, featureid,
+                           snode_[cleft].weight, snode_[cright].weight);
 
-      const int32_t cleft = (*p_tree)[nid].LeftChild();
-      const int32_t cright = (*p_tree)[nid].RightChild();
+      this->EvaluateSplit(cleft, gmat, hist_, *p_fmat, *p_tree);
+      this->EvaluateSplit(cright, gmat, hist_, *p_fmat, *p_tree);
 
-      qexpand_loss_guided_->push(ExpandEntry(cleft, -1, nid, p_tree->GetDepth(cleft),
+      qexpand_loss_guided_->push(ExpandEntry(cleft, p_tree->GetDepth(cleft),
                                             snode_[cleft].best.loss_chg,
                                             timestamp++));
-      qexpand_loss_guided_->push(ExpandEntry(cright, -1, nid, p_tree->GetDepth(cright),
+      qexpand_loss_guided_->push(ExpandEntry(cright, p_tree->GetDepth(cright),
                                             snode_[cright].best.loss_chg,
                                             timestamp++));
 
+      ++num_leaves;  // give two and take one, as parent is no longer a leaf
    }
  }
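
ExpandWithLossGuide above orders work with a max-priority queue keyed on loss_chg, so the highest-gain split anywhere in the tree is always expanded next, subject to the depth and leaf limits. A self-contained sketch of that control flow; the node ids and decayed gains below are fabricated purely for illustration:

    #include <iostream>
    #include <queue>
    #include <vector>

    struct Candidate {
      int nid;
      double loss_chg;  // gain of the best split at this node
    };

    struct GainOrder {  // largest gain on top
      bool operator()(const Candidate& a, const Candidate& b) const {
        return a.loss_chg < b.loss_chg;
      }
    };

    int main() {
      std::priority_queue<Candidate, std::vector<Candidate>, GainOrder> q;
      q.push({0, 1.5});  // root
      int max_leaves = 3, num_leaves = 1;
      while (!q.empty() && num_leaves < max_leaves) {
        Candidate c = q.top(); q.pop();
        std::cout << "expand node " << c.nid << " (gain " << c.loss_chg << ")\n";
        // a real updater would apply the split, build the smaller child's
        // histogram, derive the sibling by subtraction, and evaluate both;
        // here we just enqueue two dummy children with decayed gains
        q.push({2 * c.nid + 1, c.loss_chg * 0.5});
        q.push({2 * c.nid + 2, c.loss_chg * 0.3});
        ++num_leaves;  // give two and take one
      }
      return 0;
    }
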
QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat, HostDeviceVector* gpair, DMatrix* p_fmat, RegTree* p_tree) { - perf_monitor.StartPerfMonitor(); + builder_monitor_.Start("Update"); const std::vector& gpair_h = gpair->ConstHostVector(); + spliteval_->Reset(); - perf_monitor.TickStart(); this->InitData(gmat, gpair_h, *p_fmat, *p_tree); - perf_monitor.UpdatePerfTimer(TreeGrowingPerfMonitor::timer_name::INIT_DATA); if (param_.grow_policy == TrainParam::kLossGuide) { ExpandWithLossGuide(gmat, gmatb, column_matrix, p_fmat, p_tree, gpair_h); @@ -812,18 +337,17 @@ void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat, for (int nid = 0; nid < p_tree->param.num_nodes; ++nid) { p_tree->Stat(nid).loss_chg = snode_[nid].best.loss_chg; p_tree->Stat(nid).base_weight = snode_[nid].weight; - p_tree->Stat(nid).sum_hess = - static_cast(snode_[nid].stats.sum_hess); + p_tree->Stat(nid).sum_hess = static_cast(snode_[nid].stats.sum_hess); } pruner_->Update(gpair, p_fmat, std::vector{p_tree}); - perf_monitor.EndPerfMonitor(); + builder_monitor_.Stop("Update"); } bool QuantileHistMaker::Builder::UpdatePredictionCache( - const DMatrix* data, - HostDeviceVector* p_out_preds) { + const DMatrix* data, + HostDeviceVector* p_out_preds) { std::vector& out_preds = p_out_preds->HostVector(); // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in @@ -839,31 +363,8 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache( CHECK_GT(out_preds.size(), 0U); - const size_t block_size = 2048; - const size_t n_nodes = row_set_collection_.end() - row_set_collection_.begin(); - std::vector tasks_elem; - std::vector tasks_iblock; - std::vector tasks_nblock; - - for (size_t k = 0; k < n_nodes; ++k) { - const size_t nrows = row_set_collection_[k].Size(); - const size_t nblocks = nrows / block_size + !!(nrows % block_size); - - for (size_t i = 0; i < nblocks; ++i) { - tasks_elem.push_back(row_set_collection_[k]); - tasks_iblock.push_back(i); - tasks_nblock.push_back(nblocks); - } - } - -#pragma omp parallel for schedule(static) - for (omp_ulong k = 0; k < tasks_elem.size(); ++k) { - const RowSetCollection::Elem rowset = tasks_elem[k]; - if (rowset.begin != nullptr && rowset.end != nullptr && rowset.node_id != -1) { - const size_t nrows = rowset.Size(); - const size_t iblock = tasks_iblock[k]; - const size_t nblocks = tasks_nblock[k]; - + for (const RowSetCollection::Elem rowset : row_set_collection_) { + if (rowset.begin != nullptr && rowset.end != nullptr) { int nid = rowset.node_id; bst_float leaf_value; // if a node is marked as deleted by the pruner, traverse upward to locate @@ -876,11 +377,8 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache( } leaf_value = (*p_last_tree_)[nid].LeafValue(); - const size_t istart = iblock*block_size; - const size_t iend = (iblock == nblocks-1) ? 
nrows : istart + block_size; - - for (size_t it = istart; it < iend; ++it) { - out_preds[rowset.begin[it]] += leaf_value; + for (const size_t* it = rowset.begin; it < rowset.end; ++it) { + out_preds[*it] += leaf_value; } } } @@ -901,6 +399,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, CHECK(param_.max_depth > 0) << "max_depth cannot be 0 (unlimited) " << "when grow_policy is depthwise."; } + builder_monitor_.Start("InitData"); const auto& info = fmat.Info(); { @@ -911,16 +410,12 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, // initialize histogram collection uint32_t nbins = gmat.cut.Ptrs().back(); hist_.Init(nbins); - hist_buff_.Init(nbins); // initialize histogram builder - #pragma omp parallel +#pragma omp parallel { this->nthread_ = omp_get_num_threads(); } - - const auto nthread = static_cast(this->nthread_); - row_split_tloc_.resize(nthread); hist_builder_.Init(this->nthread_, nbins); CHECK_EQ(info.root_index_.size(), 0U); @@ -962,7 +457,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, } bool has_neg_hess = false; - for (int32_t tid = 0; tid < this->nthread_; ++tid) { + for (size_t tid = 0; tid < this->nthread_; ++tid) { if (p_buff[tid]) { has_neg_hess = true; } @@ -990,8 +485,8 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, } } } + row_set_collection_.Init(); - buffer_for_partition_.reserve(2 * info.num_row_); { /* determine layout of data */ @@ -1054,123 +549,290 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, qexpand_depth_wise_.clear(); } } + builder_monitor_.Stop("InitData"); } -void QuantileHistMaker::Builder::EvaluateSplitsBatch( - const std::vector& nodes, - const GHistIndexMatrix& gmat, - const DMatrix& fmat, - const std::vector>& hist_is_init, - const std::vector>& hist_buffers) { - perf_monitor.TickStart(); +void QuantileHistMaker::Builder::EvaluateSplit(const int nid, + const GHistIndexMatrix& gmat, + const HistCollection& hist, + const DMatrix& fmat, + const RegTree& tree) { + builder_monitor_.Start("EvaluateSplit"); + // start enumeration const MetaInfo& info = fmat.Info(); - // prepare tasks - std::vector> tasks; - for (size_t i = 0; i < nodes.size(); ++i) { - auto p_feature_set = column_sampler_.GetFeatureSet(nodes[i].depth); - - const auto& feature_set = p_feature_set->HostVector(); - const auto nfeature = static_cast(feature_set.size()); - for (size_t j = 0; j < nfeature; ++j) { - tasks.emplace_back(i, feature_set[j]); - } - } + auto p_feature_set = column_sampler_.GetFeatureSet(tree.GetDepth(nid)); + const auto& feature_set = p_feature_set->HostVector(); + const auto nfeature = static_cast(feature_set.size()); + const auto nthread = static_cast(this->nthread_); + best_split_tloc_.resize(nthread); +#pragma omp parallel for schedule(static) num_threads(nthread) + for (bst_omp_uint tid = 0; tid < nthread; ++tid) { + best_split_tloc_[tid] = snode_[nid].best; + } + GHistRow node_hist = hist[nid]; + +#pragma omp parallel for schedule(dynamic) num_threads(nthread) + for (bst_omp_uint i = 0; i < nfeature; ++i) { // NOLINT(*) + const auto feature_id = static_cast(feature_set[i]); + const auto tid = static_cast(omp_get_thread_num()); + const auto node_id = static_cast(nid); + // Narrow search space by dropping features that are not feasible under the + // given set of constraints (e.g. 
feature interaction constraints) + if (spliteval_->CheckFeatureConstraint(node_id, feature_id)) { + this->EnumerateSplit(-1, gmat, node_hist, snode_[nid], info, + &best_split_tloc_[tid], feature_id, node_id); + this->EnumerateSplit(+1, gmat, node_hist, snode_[nid], info, + &best_split_tloc_[tid], feature_id, node_id); + } + } + for (unsigned tid = 0; tid < nthread; ++tid) { + snode_[nid].best.Update(best_split_tloc_[tid]); + } + builder_monitor_.Stop("EvaluateSplit"); +} - // rabit::IsDistributed is not thread-safe - auto isDistributed = rabit::IsDistributed(); - // partial results - std::vector> splits(tasks.size()); - // parallel enumeration - #pragma omp parallel for schedule(static) - for (omp_ulong i = 0; i < tasks.size(); ++i) { - // node_idx : offset within `nodes` list - const int32_t node_idx = tasks[i].first; - const size_t fid = tasks[i].second; - const int32_t nid = nodes[node_idx].nid; // usually node_idx != nid - const int32_t sibling_nid = nodes[node_idx].sibling_nid; - const int32_t parent_nid = nodes[node_idx].parent_nid; - - // reduce needed part of a hist here to have it in cache before enumeration - if (!isDistributed) { - auto hist_data = reinterpret_cast(hist_[nid].data()); - auto sibling_hist_data = sibling_nid > -1 ? - reinterpret_cast( - hist_[sibling_nid].data()) : nullptr; - auto parent_hist_data = sibling_nid > -1 ? - reinterpret_cast( - hist_[parent_nid].data()) : nullptr; - - const std::vector& cut_ptr = gmat.cut.Ptrs(); - const size_t ibegin = 2 * cut_ptr[fid]; - const size_t iend = 2 * cut_ptr[fid + 1]; - ReduceHistograms(hist_data, sibling_hist_data, parent_hist_data, ibegin, iend, node_idx, - hist_is_init, hist_buffers); +void QuantileHistMaker::Builder::ApplySplit(int nid, + const GHistIndexMatrix& gmat, + const ColumnMatrix& column_matrix, + const HistCollection& hist, + const DMatrix& fmat, + RegTree* p_tree) { + builder_monitor_.Start("ApplySplit"); + // TODO(hcho3): support feature sampling by levels + + /* 1. Create child nodes */ + NodeEntry& e = snode_[nid]; + bst_float left_leaf_weight = + spliteval_->ComputeWeight(nid, e.best.left_sum) * param_.learning_rate; + bst_float right_leaf_weight = + spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate; + p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, + e.best.DefaultLeft(), e.weight, left_leaf_weight, + right_leaf_weight, e.best.loss_chg, e.stats.sum_hess); + + /* 2. 
Categorize member rows */ + const auto nthread = static_cast(this->nthread_); + row_split_tloc_.resize(nthread); + for (bst_omp_uint i = 0; i < nthread; ++i) { + row_split_tloc_[i].left.clear(); + row_split_tloc_[i].right.clear(); + } + const bool default_left = (*p_tree)[nid].DefaultLeft(); + const bst_uint fid = (*p_tree)[nid].SplitIndex(); + const bst_float split_pt = (*p_tree)[nid].SplitCond(); + const uint32_t lower_bound = gmat.cut.Ptrs()[fid]; + const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1]; + int32_t split_cond = -1; + // convert floating-point split_pt into corresponding bin_id + // split_cond = -1 indicates that split_pt is less than all known cut points + CHECK_LT(upper_bound, + static_cast(std::numeric_limits::max())); + for (uint32_t i = lower_bound; i < upper_bound; ++i) { + if (split_pt == gmat.cut.Values()[i]) { + split_cond = static_cast(i); } + } - if (spliteval_->CheckFeatureConstraint(nid, fid)) { - auto& snode = snode_[nid]; - const bool compute_backward = this->EnumerateSplit(+1, gmat, hist_[nid], snode, - info, &splits[i].first, fid, nid); - - // Sometimes, we don't need to enumerate backward because forward and backward - // enumeration will give same loss values. This is the case if the particular feature - // column contains no missing values. So enumerate backward only if it's necessary. - if (compute_backward) { - this->EnumerateSplit(-1, gmat, hist_[nid], snode, info, - &splits[i].first, fid, nid); - } - } + const auto& rowset = row_set_collection_[nid]; - if (sibling_nid > -1 && spliteval_->CheckFeatureConstraint(sibling_nid, fid)) { - auto& snode = snode_[sibling_nid]; + Column column = column_matrix.GetColumn(fid); + if (column.GetType() == xgboost::common::kDenseColumn) { + ApplySplitDenseData(rowset, gmat, &row_split_tloc_, column, split_cond, + default_left); + } else { + ApplySplitSparseData(rowset, gmat, &row_split_tloc_, column, lower_bound, + upper_bound, split_cond, default_left); + } - const bool compute_backward = this->EnumerateSplit(+1, gmat, hist_[sibling_nid], snode, - info, &splits[i].second, fid, sibling_nid); + row_set_collection_.AddSplit( + nid, row_split_tloc_, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild()); + builder_monitor_.Stop("ApplySplit"); +} - if (compute_backward) { - this->EnumerateSplit(-1, gmat, hist_[sibling_nid], snode, info, - &splits[i].second, fid, sibling_nid); +void QuantileHistMaker::Builder::ApplySplitDenseData( + const RowSetCollection::Elem rowset, + const GHistIndexMatrix& gmat, + std::vector* p_row_split_tloc, + const Column& column, + bst_int split_cond, + bool default_left) { + std::vector& row_split_tloc = *p_row_split_tloc; + constexpr int kUnroll = 8; // loop unrolling factor + const size_t nrows = rowset.end - rowset.begin; + const size_t rest = nrows % kUnroll; + +#pragma omp parallel for num_threads(nthread_) schedule(static) + for (bst_omp_uint i = 0; i < nrows - rest; i += kUnroll) { + const bst_uint tid = omp_get_thread_num(); + auto& left = row_split_tloc[tid].left; + auto& right = row_split_tloc[tid].right; + size_t rid[kUnroll]; + uint32_t rbin[kUnroll]; + for (int k = 0; k < kUnroll; ++k) { + rid[k] = rowset.begin[i + k]; + } + for (int k = 0; k < kUnroll; ++k) { + rbin[k] = column.GetFeatureBinIdx(rid[k]); + } + for (int k = 0; k < kUnroll; ++k) { // NOLINT + if (rbin[k] == std::numeric_limits::max()) { // missing value + if (default_left) { + left.push_back(rid[k]); + } else { + right.push_back(rid[k]); + } + } else { + if (static_cast(rbin[k] + column.GetBaseIdx()) <= 
split_cond) { + left.push_back(rid[k]); + } else { + right.push_back(rid[k]); + } } } } - - // choice of the best splits - for (size_t i = 0; i < splits.size(); ++i) { - const int32_t node_idx = tasks[i].first; - const int32_t nid = nodes[node_idx].nid; - const int32_t sibling_nid = nodes[node_idx].sibling_nid; - snode_[nid].best.Update(splits[i].first); - if (sibling_nid > -1) { - snode_[sibling_nid].best.Update(splits[i].second); + for (size_t i = nrows - rest; i < nrows; ++i) { + auto& left = row_split_tloc[nthread_-1].left; + auto& right = row_split_tloc[nthread_-1].right; + const size_t rid = rowset.begin[i]; + const uint32_t rbin = column.GetFeatureBinIdx(rid); + if (rbin == std::numeric_limits::max()) { // missing value + if (default_left) { + left.push_back(rid); + } else { + right.push_back(rid); + } + } else { + if (static_cast(rbin + column.GetBaseIdx()) <= split_cond) { + left.push_back(rid); + } else { + right.push_back(rid); + } } } +} + +void QuantileHistMaker::Builder::ApplySplitSparseData( + const RowSetCollection::Elem rowset, + const GHistIndexMatrix& gmat, + std::vector* p_row_split_tloc, + const Column& column, + bst_uint lower_bound, + bst_uint upper_bound, + bst_int split_cond, + bool default_left) { + std::vector& row_split_tloc = *p_row_split_tloc; + const size_t nrows = rowset.end - rowset.begin; + +#pragma omp parallel num_threads(nthread_) + { + const auto tid = static_cast(omp_get_thread_num()); + const size_t ibegin = tid * nrows / nthread_; + const size_t iend = (tid + 1) * nrows / nthread_; + if (ibegin < iend) { // ensure that [ibegin, iend) is nonempty range + // search first nonzero row with index >= rowset[ibegin] + const size_t* p = std::lower_bound(column.GetRowData(), + column.GetRowData() + column.Size(), + rowset.begin[ibegin]); + + auto& left = row_split_tloc[tid].left; + auto& right = row_split_tloc[tid].right; + if (p != column.GetRowData() + column.Size() && *p <= rowset.begin[iend - 1]) { + size_t cursor = p - column.GetRowData(); - perf_monitor.UpdatePerfTimer(TreeGrowingPerfMonitor::timer_name::EVALUATE_SPLIT); + for (size_t i = ibegin; i < iend; ++i) { + const size_t rid = rowset.begin[i]; + while (cursor < column.Size() + && column.GetRowIdx(cursor) < rid + && column.GetRowIdx(cursor) <= rowset.begin[iend - 1]) { + ++cursor; + } + if (cursor < column.Size() && column.GetRowIdx(cursor) == rid) { + const uint32_t rbin = column.GetFeatureBinIdx(cursor); + if (static_cast(rbin + column.GetBaseIdx()) <= split_cond) { + left.push_back(rid); + } else { + right.push_back(rid); + } + ++cursor; + } else { + // missing value + if (default_left) { + left.push_back(rid); + } else { + right.push_back(rid); + } + } + } + } else { // all rows in [ibegin, iend) have missing values + if (default_left) { + for (size_t i = ibegin; i < iend; ++i) { + const size_t rid = rowset.begin[i]; + left.push_back(rid); + } + } else { + for (size_t i = ibegin; i < iend; ++i) { + const size_t rid = rowset.begin[i]; + right.push_back(rid); + } + } + } + } + } } void QuantileHistMaker::Builder::InitNewNode(int nid, const GHistIndexMatrix& gmat, const std::vector& gpair, const DMatrix& fmat, - RegTree* tree, - QuantileHistMaker::NodeEntry* snode, - int32_t parentid) { + const RegTree& tree) { + builder_monitor_.Start("InitNewNode"); + { + snode_.resize(tree.param.num_nodes, NodeEntry(param_)); + } + + { + auto& stats = snode_[nid].stats; + GHistRow hist = hist_[nid]; + if (tree[nid].IsRoot()) { + if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) { 
+ const std::vector& row_ptr = gmat.cut.Ptrs(); + const uint32_t ibegin = row_ptr[fid_least_bins_]; + const uint32_t iend = row_ptr[fid_least_bins_ + 1]; + auto begin = hist.data(); + for (uint32_t i = ibegin; i < iend; ++i) { + const GradStats et = begin[i]; + stats.Add(et.sum_grad, et.sum_hess); + } + } else { + const RowSetCollection::Elem e = row_set_collection_[nid]; + for (const size_t* it = e.begin; it < e.end; ++it) { + stats.Add(gpair[*it]); + } + } + histred_.Allreduce(&snode_[nid].stats, 1); + } else { + int parent_id = tree[nid].Parent(); + if (tree[nid].IsLeftChild()) { + snode_[nid].stats = snode_[parent_id].best.left_sum; + } else { + snode_[nid].stats = snode_[parent_id].best.right_sum; + } + } + } + // calculating the weights { - snode->weight = static_cast( - spliteval_->ComputeWeight(parentid, snode->stats)); - snode->root_gain = static_cast( - spliteval_->ComputeScore(parentid, snode->stats, - snode->weight)); + bst_uint parentid = tree[nid].Parent(); + snode_[nid].weight = static_cast( + spliteval_->ComputeWeight(parentid, snode_[nid].stats)); + snode_[nid].root_gain = static_cast( + spliteval_->ComputeScore(parentid, snode_[nid].stats, snode_[nid].weight)); } + builder_monitor_.Stop("InitNewNode"); } // enumerate the split values of specific feature -// d_step: +1 or -1, indicating direction at which we scan candidate thresholds in order -// fid: feature for which we seek to pick best threshold -// Returns false if we don't need to enumerate in opposite direction. -// This is the case if the particular feature (fid) column contains no missing values. -bool QuantileHistMaker::Builder::EnumerateSplit(int d_step, +void QuantileHistMaker::Builder::EnumerateSplit(int d_step, const GHistIndexMatrix& gmat, const GHistRow& hist, const NodeEntry& snode, @@ -1209,54 +871,39 @@ bool QuantileHistMaker::Builder::EnumerateSplit(int d_step, iend = static_cast(cut_ptr[fid]) - 1; } - if (d_step == 1) { - for (int32_t i = ibegin; i < iend; i++) { - e.Add(hist[i].GetGrad(), hist[i].GetHess()); - if (e.sum_hess >= param_.min_child_weight) { - c.SetSubstract(snode.stats, e); - if (c.sum_hess >= param_.min_child_weight) { - bst_float loss_chg = static_cast(spliteval_->ComputeSplitScore(nodeID, - fid, e, c) - snode.root_gain); - bst_float split_pt = cut_val[i]; - best.Update(loss_chg, fid, split_pt, false, e, c); - } - } - } - p_best->Update(best); - - if (e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess()) { - return false; - } - } else { - for (int32_t i = ibegin; i != iend; i--) { - e.Add(hist[i].GetGrad(), hist[i].GetHess()); - if (e.sum_hess >= param_.min_child_weight) { - c.SetSubstract(snode.stats, e); - if (c.sum_hess >= param_.min_child_weight) { - bst_float split_pt; + for (int32_t i = ibegin; i != iend; i += d_step) { + // start working + // try to find a split + e.Add(hist[i].GetGrad(), hist[i].GetHess()); + if (e.sum_hess >= param_.min_child_weight) { + c.SetSubstract(snode.stats, e); + if (c.sum_hess >= param_.min_child_weight) { + bst_float loss_chg; + bst_float split_pt; + if (d_step > 0) { + // forward enumeration: split at right bound of each bin + loss_chg = static_cast( + spliteval_->ComputeSplitScore(nodeID, fid, e, c) - + snode.root_gain); + split_pt = cut_val[i]; + best.Update(loss_chg, fid, split_pt, d_step == -1, e, c); + } else { // backward enumeration: split at left bound of each bin - bst_float loss_chg = static_cast( + loss_chg = static_cast( spliteval_->ComputeSplitScore(nodeID, fid, c, e) - snode.root_gain); - if (i == imin) { // for 
leftmost bin, left bound is the smallest feature value split_pt = gmat.cut.MinValues()[fid]; } else { split_pt = cut_val[i - 1]; } - best.Update(loss_chg, fid, split_pt, true, c, e); + best.Update(loss_chg, fid, split_pt, d_step == -1, c, e); } } } - p_best->Update(best); - - if (e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess()) { - return false; - } } - - return true; + p_best->Update(best); } XGBOOST_REGISTER_TREE_UPDATER(FastHistMaker, "grow_fast_histmaker") diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index 6ac3949eb1b5..2224971130a4 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -1,8 +1,8 @@ /*! - * Copyright 2017-2019 by Contributors + * Copyright 2017-2018 by Contributors * \file updater_quantile_hist.h * \brief use quantized feature values to construct a tree - * \author Philip Cho, Tianqi Chen, Egor Smirnov + * \author Philip Cho, Tianqi Chen */ #ifndef XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_ #define XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_ @@ -18,19 +18,51 @@ #include #include #include -#include #include "./param.h" #include "./split_evaluator.h" #include "../common/random.h" +#include "../common/timer.h" #include "../common/hist_util.h" #include "../common/row_set.h" #include "../common/column_matrix.h" namespace xgboost { -namespace common { - struct GradStatHist; -} + +/*! + * \brief A C-style array with in-stack allocation. As long as the array is smaller than MaxStackSize, it will be allocated inside the stack. Otherwise, it will be heap-allocated. + */ +template +class MemStackAllocator { + public: + explicit MemStackAllocator(size_t required_size): required_size_(required_size) { + } + + T* Get() { + if (!ptr_) { + if (MaxStackSize >= required_size_) { + ptr_ = stack_mem_; + } else { + ptr_ = reinterpret_cast(malloc(required_size_ * sizeof(T))); + do_free_ = true; + } + } + + return ptr_; + } + + ~MemStackAllocator() { + if (do_free_) free(ptr_); + } + + + private: + T* ptr_ = nullptr; + bool do_free_ = false; + size_t required_size_; + T stack_mem_[MaxStackSize]; +}; + namespace tree { using xgboost::common::GHistIndexMatrix; @@ -71,7 +103,6 @@ class QuantileHistMaker: public TreeUpdater { bool is_gmat_initialized_; // data structure - public: struct NodeEntry { /*! 
\brief statics for node entry */ GradStats stats; @@ -83,8 +114,7 @@ class QuantileHistMaker: public TreeUpdater { SplitEntry best; // constructor explicit NodeEntry(const TrainParam& param) - : root_gain(0.0f), weight(0.0f) { - } + : root_gain(0.0f), weight(0.0f) {} }; // actual builder that runs the algorithm @@ -94,8 +124,11 @@ class QuantileHistMaker: public TreeUpdater { explicit Builder(const TrainParam& param, std::unique_ptr pruner, std::unique_ptr spliteval) - : param_(param), pruner_(std::move(pruner)), spliteval_(std::move(spliteval)), - p_last_tree_(nullptr), p_last_fmat_(nullptr) { } + : param_(param), pruner_(std::move(pruner)), + spliteval_(std::move(spliteval)), p_last_tree_(nullptr), + p_last_fmat_(nullptr) { + builder_monitor_.Init("Quantile::Builder"); + } // update one tree, growing virtual void Update(const GHistIndexMatrix& gmat, const GHistIndexBlockMatrix& gmatb, @@ -104,104 +137,42 @@ class QuantileHistMaker: public TreeUpdater { DMatrix* p_fmat, RegTree* p_tree); + inline void BuildHist(const std::vector& gpair, + const RowSetCollection::Elem row_indices, + const GHistIndexMatrix& gmat, + const GHistIndexBlockMatrix& gmatb, + GHistRow hist, + bool sync_hist) { + builder_monitor_.Start("BuildHist"); + if (param_.enable_feature_grouping > 0) { + hist_builder_.BuildBlockHist(gpair, row_indices, gmatb, hist); + } else { + hist_builder_.BuildHist(gpair, row_indices, gmat, hist); + } + if (sync_hist) { + this->histred_.Allreduce(hist.data(), hist_builder_.GetNumBins()); + } + builder_monitor_.Stop("BuildHist"); + } + + inline void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) { + builder_monitor_.Start("SubtractionTrick"); + hist_builder_.SubtractionTrick(self, sibling, parent); + builder_monitor_.Stop("SubtractionTrick"); + } + bool UpdatePredictionCache(const DMatrix* data, HostDeviceVector* p_out_preds); - std::tuple - GetHistBuffer(std::vector* hist_is_init, - std::vector* grad_stats, size_t block_id, size_t nthread, - size_t tid, std::vector* data_hist, size_t hist_size); - protected: /* tree growing policies */ struct ExpandEntry { int nid; - int sibling_nid; - int parent_nid; int depth; bst_float loss_chg; unsigned timestamp; - ExpandEntry(int nid, int sibling_nid, int parent_nid, int depth, bst_float loss_chg, - unsigned tstmp) : nid(nid), sibling_nid(sibling_nid), parent_nid(parent_nid), - depth(depth), loss_chg(loss_chg), timestamp(tstmp) {} - }; - - struct TreeGrowingPerfMonitor { - enum timer_name {INIT_DATA, INIT_NEW_NODE, BUILD_HIST, EVALUATE_SPLIT, APPLY_SPLIT}; - - double global_start; - - // performance counters - double tstart; - double time_init_data = 0; - double time_init_new_node = 0; - double time_build_hist = 0; - double time_evaluate_split = 0; - double time_apply_split = 0; - - inline void StartPerfMonitor() { - global_start = dmlc::GetTime(); - } - - inline void EndPerfMonitor() { - CHECK_GT(global_start, 0); - double total_time = dmlc::GetTime() - global_start; - LOG(INFO) << "\nInitData: " - << std::fixed << std::setw(6) << std::setprecision(4) << time_init_data - << " (" << std::fixed << std::setw(5) << std::setprecision(2) - << time_init_data / total_time * 100 << "%)\n" - << "InitNewNode: " - << std::fixed << std::setw(6) << std::setprecision(4) << time_init_new_node - << " (" << std::fixed << std::setw(5) << std::setprecision(2) - << time_init_new_node / total_time * 100 << "%)\n" - << "BuildHist: " - << std::fixed << std::setw(6) << std::setprecision(4) << time_build_hist - << " (" << std::fixed << std::setw(5) << 
std::setprecision(2) - << time_build_hist / total_time * 100 << "%)\n" - << "EvaluateSplit: " - << std::fixed << std::setw(6) << std::setprecision(4) << time_evaluate_split - << " (" << std::fixed << std::setw(5) << std::setprecision(2) - << time_evaluate_split / total_time * 100 << "%)\n" - << "ApplySplit: " - << std::fixed << std::setw(6) << std::setprecision(4) << time_apply_split - << " (" << std::fixed << std::setw(5) << std::setprecision(2) - << time_apply_split / total_time * 100 << "%)\n" - << "========================================\n" - << "Total: " - << std::fixed << std::setw(6) << std::setprecision(4) << total_time << std::endl; - // clear performance counters - time_init_data = 0; - time_init_new_node = 0; - time_build_hist = 0; - time_evaluate_split = 0; - time_apply_split = 0; - } - - inline void TickStart() { - tstart = dmlc::GetTime(); - } - - inline void UpdatePerfTimer(const timer_name &timer_name) { - // CHECK_GT(tstart, 0); // TODO Fix - switch (timer_name) { - case INIT_DATA: - time_init_data += dmlc::GetTime() - tstart; - break; - case INIT_NEW_NODE: - time_init_new_node += dmlc::GetTime() - tstart; - break; - case BUILD_HIST: - time_build_hist += dmlc::GetTime() - tstart; - break; - case EVALUATE_SPLIT: - time_evaluate_split += dmlc::GetTime() - tstart; - break; - case APPLY_SPLIT: - time_apply_split += dmlc::GetTime() - tstart; - break; - } - tstart = -1; - } + ExpandEntry(int nid, int depth, bst_float loss_chg, unsigned tstmp) + : nid(nid), depth(depth), loss_chg(loss_chg), timestamp(tstmp) {} }; // initialize temp data structure @@ -210,16 +181,43 @@ class QuantileHistMaker: public TreeUpdater { const DMatrix& fmat, const RegTree& tree); + void EvaluateSplit(const int nid, + const GHistIndexMatrix& gmat, + const HistCollection& hist, + const DMatrix& fmat, + const RegTree& tree); + + void ApplySplit(int nid, + const GHistIndexMatrix& gmat, + const ColumnMatrix& column_matrix, + const HistCollection& hist, + const DMatrix& fmat, + RegTree* p_tree); + + void ApplySplitDenseData(const RowSetCollection::Elem rowset, + const GHistIndexMatrix& gmat, + std::vector* p_row_split_tloc, + const Column& column, + bst_int split_cond, + bool default_left); + + void ApplySplitSparseData(const RowSetCollection::Elem rowset, + const GHistIndexMatrix& gmat, + std::vector* p_row_split_tloc, + const Column& column, + bst_uint lower_bound, + bst_uint upper_bound, + bst_int split_cond, + bool default_left); + void InitNewNode(int nid, const GHistIndexMatrix& gmat, const std::vector& gpair, const DMatrix& fmat, - RegTree* tree, - QuantileHistMaker::NodeEntry* snode, - int32_t parentid); + const RegTree& tree); // enumerate the split values of specific feature - bool EnumerateSplit(int d_step, + void EnumerateSplit(int d_step, const GHistIndexMatrix& gmat, const GHistRow& hist, const NodeEntry& snode, @@ -228,36 +226,37 @@ class QuantileHistMaker: public TreeUpdater { bst_uint fid, bst_uint nodeID); - void EvaluateSplitsBatch(const std::vector& nodes, - const GHistIndexMatrix& gmat, - const DMatrix& fmat, - const std::vector>& hist_is_init, - const std::vector>& hist_buffers); - - void ReduceHistograms( - common::GradStatHist::GradType* hist_data, - common::GradStatHist::GradType* sibling_hist_data, - common::GradStatHist::GradType* parent_hist_data, - const size_t ibegin, - const size_t iend, - const size_t inode, - const std::vector>& hist_is_init, - const std::vector>& hist_buffers); - - void SyncHistograms( - RegTree* p_tree, - const std::vector& nodes, - std::vector>* 
hist_buffers, - std::vector>* hist_is_init, - const std::vector>& grad_stats); - - void ExpandWithDepthWise(const GHistIndexMatrix &gmat, + void ExpandWithDepthWise(const GHistIndexMatrix &gmat, + const GHistIndexBlockMatrix &gmatb, + const ColumnMatrix &column_matrix, + DMatrix *p_fmat, + RegTree *p_tree, + const std::vector &gpair_h); + + void BuildLocalHistograms(int *starting_index, + int *sync_count, + const GHistIndexMatrix &gmat, const GHistIndexBlockMatrix &gmatb, - const ColumnMatrix &column_matrix, - DMatrix *p_fmat, RegTree *p_tree, const std::vector &gpair_h); + void SyncHistograms(int starting_index, + int sync_count, + RegTree *p_tree); + + void BuildNodeStats(const GHistIndexMatrix &gmat, + DMatrix *p_fmat, + RegTree *p_tree, + const std::vector &gpair_h); + + void EvaluateSplits(const GHistIndexMatrix &gmat, + const ColumnMatrix &column_matrix, + DMatrix *p_fmat, + RegTree *p_tree, + int *num_leaves, + int depth, + unsigned *timestamp, + std::vector *temp_qexpand_depth); void ExpandWithLossGuide(const GHistIndexMatrix& gmat, const GHistIndexBlockMatrix& gmatb, @@ -266,62 +265,6 @@ class QuantileHistMaker: public TreeUpdater { RegTree* p_tree, const std::vector& gpair_h); - - void BuildHistsBatch(const std::vector& nodes, RegTree* tree, - const GHistIndexMatrix &gmat, const std::vector& gpair, - std::vector>* hist_buffers, - std::vector>* hist_is_init); - - void BuildNodeStat(const GHistIndexMatrix &gmat, - DMatrix *p_fmat, - RegTree *p_tree, - const std::vector &gpair_h, - int32_t nid); - - void BuildNodeStatBatch( - const GHistIndexMatrix &gmat, - DMatrix *p_fmat, - RegTree *p_tree, - const std::vector &gpair_h, - const std::vector& nodes); - - int32_t FindSplitCond(int32_t nid, - RegTree *p_tree, - const GHistIndexMatrix &gmat); - - void CreateNewNodesBatch( - const std::vector& nodes, - const GHistIndexMatrix &gmat, - const ColumnMatrix &column_matrix, - DMatrix *p_fmat, - RegTree *p_tree, - int *num_leaves, - int depth, - unsigned *timestamp, - std::vector *temp_qexpand_depth); - - template - void CreateTasksForApplySplit( - const std::vector& nodes, - const GHistIndexMatrix &gmat, - RegTree *p_tree, - int *num_leaves, - const int depth, - const size_t block_size, - std::vector* tasks, - std::vector* nodes_bounds); - - void CreateTasksForBuildHist( - size_t block_size_rows, - size_t nthread, - const std::vector& nodes, - std::vector>* hist_buffers, - std::vector>* hist_is_init, - std::vector>* grad_stats, - std::vector* task_nid, - std::vector* task_node_idx, - std::vector* task_block_idx); - inline static bool LossGuide(ExpandEntry lhs, ExpandEntry rhs) { if (lhs.loss_chg == rhs.loss_chg) { return lhs.timestamp > rhs.timestamp; // favor small timestamp @@ -330,8 +273,6 @@ class QuantileHistMaker: public TreeUpdater { } } - HistCollection hist_buff_; - // --data fields-- const TrainParam& param_; // number of omp thread used during training @@ -342,7 +283,6 @@ class QuantileHistMaker: public TreeUpdater { // the temp space for split std::vector row_split_tloc_; std::vector best_split_tloc_; - std::vector buffer_for_partition_; /*! \brief TreeNode Data: statistics for each constructed node */ std::vector snode_; /*! \brief culmulative histogram of gradients. 
*/ @@ -374,8 +314,8 @@ class QuantileHistMaker: public TreeUpdater { enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData }; DataLayout data_layout_; - TreeGrowingPerfMonitor perf_monitor; - rabit::Reducer histred_; + common::Monitor builder_monitor_; + rabit::Reducer histred_; }; std::unique_ptr builder_; diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index 3437fea0faef..9420893d1c58 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -103,14 +103,8 @@ class QuantileHistMock : public QuantileHistMaker { RealImpl::InitData(gmat, gpair, fmat, tree); GHistIndexBlockMatrix dummy; hist_.AddHistRow(nid); - - std::vector> hist_buffers; - std::vector> hist_is_init; - std::vector nodes = {ExpandEntry(nid, -1, -1, tree.GetDepth(0), 0.0, 0)}; - BuildHistsBatch(nodes, const_cast(&tree), gmat, gpair, &hist_buffers, &hist_is_init); - RealImpl::InitNewNode(nid, gmat, gpair, fmat, - const_cast(&tree), &snode_[0], tree[0].Parent()); - EvaluateSplitsBatch(nodes, gmat, fmat, hist_is_init, hist_buffers); + BuildHist(gpair, row_set_collection_[nid], + gmat, dummy, hist_[nid], false); // Check if number of histogram bins is correct ASSERT_EQ(hist_[nid].size(), gmat.cut.Ptrs().back()); @@ -151,13 +145,10 @@ class QuantileHistMock : public QuantileHistMaker { RealImpl::InitData(gmat, row_gpairs, *(*dmat), tree); hist_.AddHistRow(0); - std::vector nodes = {ExpandEntry(0, -1, -1, tree.GetDepth(0), 0.0, 0)}; - std::vector> hist_buffers; - std::vector> hist_is_init; - BuildHistsBatch(nodes, const_cast(&tree), gmat, row_gpairs, &hist_buffers, &hist_is_init); - RealImpl::InitNewNode(0, gmat, row_gpairs, *(*dmat), - const_cast(&tree), &snode_[0], tree[0].Parent()); - EvaluateSplitsBatch(nodes, gmat, **dmat, hist_is_init, hist_buffers); + BuildHist(row_gpairs, row_set_collection_[0], + gmat, quantile_index_block, hist_[0], false); + + RealImpl::InitNewNode(0, gmat, row_gpairs, *(*dmat), tree); /* Compute correct split (best_split) using the computed histogram */ const size_t num_row = dmat->get()->Info().num_row_; @@ -208,7 +199,6 @@ class QuantileHistMock : public QuantileHistMaker { const auto split_gain = evaluator->ComputeSplitScore(0, fid, GradStats(left_sum), GradStats(right_sum)); - if (split_gain > best_split_gain) { best_split_gain = split_gain; best_split_feature = fid; @@ -218,8 +208,7 @@ class QuantileHistMock : public QuantileHistMaker { } /* Now compare against result given by EvaluateSplit() */ - EvaluateSplitsBatch(nodes, gmat, **dmat, hist_is_init, hist_buffers); - + RealImpl::EvaluateSplit(0, gmat, hist_, *(*dmat), tree); ASSERT_EQ(snode_[0].best.SplitIndex(), best_split_feature); ASSERT_EQ(snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]); @@ -310,7 +299,7 @@ TEST(Updater, QuantileHist_EvalSplits) { std::vector> cfg {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}, {"split_evaluator", "elastic_net"}, - {"reg_lambda", "1.0f"}, {"reg_alpha", "0"}, {"max_delta_step", "0"}, + {"reg_lambda", "0"}, {"reg_alpha", "0"}, {"max_delta_step", "0"}, {"min_child_weight", "0"}}; QuantileHistMock maker(cfg); maker.TestEvaluateSplit();