Optimize ‘hist’ for multi-core CPU (dmlc#4529)
* Initial performance optimizations for xgboost

* remove includes

* revert float->double

* fix for CI

* fix for CI

* fix for CI

* fix for CI

* fix for CI

* fix for CI

* fix for CI

* fix for CI

* fix for CI

* fix for CI

* Check existence of _mm_prefetch and __builtin_prefetch

* Fix lint

* optimizations for CPU

* applying comments in review

* add some comments, code refactoring

* fixing issues in CI

* adding runtime checks

* remove 1 extra check

* remove extra checks in BuildHist

* remove checks

* add debug info

* added debug info

* revert changes

* added comments

* Apply suggestions from code review

Co-Authored-By: Philip Hyunsu Cho <[email protected]>

* apply review comments

* Remove unused function CreateNewNodes()

* Add descriptive comment on node_idx variable in QuantileHistMaker::Builder::BuildHistsBatch()
SmirnovEgorRu authored and hcho3 committed Jun 27, 2019
1 parent abffbe0 commit 4d6590b
Showing 9 changed files with 1,341 additions and 817 deletions.
11 changes: 7 additions & 4 deletions src/common/column_matrix.h
@@ -8,11 +8,11 @@
 #ifndef XGBOOST_COMMON_COLUMN_MATRIX_H_
 #define XGBOOST_COMMON_COLUMN_MATRIX_H_
 
+#include <dmlc/timer.h>
 #include <limits>
 #include <vector>
 #include "hist_util.h"
 
-
 namespace xgboost {
 namespace common {
 
@@ -51,6 +51,10 @@ class Column {
   }
   const size_t* GetRowData() const { return row_ind_; }
 
+  const uint32_t* GetIndex() const {
+    return index_;
+  }
+
  private:
   ColumnType type_;
   const uint32_t* index_;
@@ -80,7 +84,7 @@ class ColumnMatrix {
     std::fill(feature_counts_.begin(), feature_counts_.end(), 0);
 
     uint32_t max_val = std::numeric_limits<uint32_t>::max();
-    for (bst_uint fid = 0; fid < nfeature; ++fid) {
+    for (int32_t fid = 0; fid < nfeature; ++fid) {
       CHECK_LE(gmat.cut.row_ptr[fid + 1] - gmat.cut.row_ptr[fid], max_val);
     }
@@ -113,13 +117,12 @@ class ColumnMatrix {
       boundary_[fid].index_end = accum_index_;
       boundary_[fid].row_ind_end = accum_row_ind_;
     }
-
     index_.resize(boundary_[nfeature - 1].index_end);
     row_ind_.resize(boundary_[nfeature - 1].row_ind_end);
 
     // store least bin id for each feature
     index_base_.resize(nfeature);
-    for (bst_uint fid = 0; fid < nfeature; ++fid) {
+    for (int32_t fid = 0; fid < nfeature; ++fid) {
       index_base_[fid] = gmat.cut.row_ptr[fid];
     }
 
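Note on the layout above: index_base_ records the first ("least") global bin id of every feature, so a feature-relative bin code stored in a column can be mapped back to a slot in the global histogram. A minimal standalone sketch of that arithmetic with made-up cut boundaries (none of the values below come from the diff):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Analogue of gmat.cut.row_ptr: feature f owns global bins
  // [row_ptr[f], row_ptr[f+1]).
  std::vector<uint32_t> row_ptr = {0, 4, 9, 12};
  // Analogue of index_base_: first global bin id of each feature.
  std::vector<uint32_t> index_base(row_ptr.begin(), row_ptr.end() - 1);

  uint32_t fid = 1;        // feature id
  uint32_t local_bin = 3;  // feature-relative bin code stored in the column
  uint32_t global_bin = index_base[fid] + local_bin;  // slot in the histogram
  std::printf("global bin = %u\n", global_bin);  // prints 7
  return 0;
}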
258 changes: 113 additions & 145 deletions src/common/hist_util.cc
@@ -1,15 +1,15 @@
 /*!
  * Copyright 2017-2019 by Contributors
- * \file hist_util.h
+ * \file hist_util.cc
  */
-#include "./hist_util.h"
 #include <dmlc/timer.h>
 #include <rabit/rabit.h>
 #include <dmlc/omp.h>
 #include <numeric>
 #include <vector>
 
 #include "./random.h"
 #include "./column_matrix.h"
+#include "./hist_util.h"
 #include "./quantile.h"
 #include "./../tree/updater_quantile_hist.h"
 
@@ -178,7 +178,7 @@ uint32_t HistCutMatrix::GetBinIdx(const Entry& e) {
 
 void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
   cut.Init(p_fmat, max_num_bins);
-  const size_t nthread = omp_get_max_threads();
+  const int32_t nthread = omp_get_max_threads();
   const uint32_t nbins = cut.row_ptr.back();
   hit_count.resize(nbins, 0);
   hit_count_tloc_.resize(nthread * nbins, 0);
@@ -260,8 +260,8 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
   }
 
 #pragma omp parallel for num_threads(nthread) schedule(static)
-  for (bst_omp_uint idx = 0; idx < bst_omp_uint(nbins); ++idx) {
-    for (size_t tid = 0; tid < nthread; ++tid) {
+  for (int32_t idx = 0; idx < int32_t(nbins); ++idx) {
+    for (int32_t tid = 0; tid < nthread; ++tid) {
       hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
     }
   }
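The hunk above is the reduction half of a two-phase pattern: each thread first counts hits into its own strip of hit_count_tloc_, then the strips are summed into hit_count by a parallel loop over bins. A standalone sketch of the pattern (names are illustrative, not from the codebase):

#include <omp.h>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const int nbins = 8;
  const std::vector<int> bin_of_row = {0, 3, 3, 7, 1, 0, 3, 2};
  const int nthread = omp_get_max_threads();

  // Phase 1: one private histogram strip per thread, no atomics needed.
  std::vector<int> tloc(static_cast<std::size_t>(nthread) * nbins, 0);
#pragma omp parallel for schedule(static)
  for (int i = 0; i < static_cast<int>(bin_of_row.size()); ++i) {
    tloc[static_cast<std::size_t>(omp_get_thread_num()) * nbins + bin_of_row[i]] += 1;
  }

  // Phase 2: reduce across threads, parallelized over bins as in the diff.
  std::vector<int> hit_count(nbins, 0);
#pragma omp parallel for schedule(static)
  for (int idx = 0; idx < nbins; ++idx) {
    for (int tid = 0; tid < nthread; ++tid) {
      hit_count[idx] += tloc[static_cast<std::size_t>(tid) * nbins + idx];
    }
  }
  std::printf("bin 3 hit count = %d\n", hit_count[3]);  // prints 3
  return 0;
}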
@@ -411,7 +411,7 @@ FastFeatureGrouping(const GHistIndexMatrix& gmat,
     for (auto fid : group) {
       nnz += feature_nnz[fid];
     }
-    double nnz_rate = static_cast<double>(nnz) / nrow;
+    float nnz_rate = static_cast<float>(nnz) / nrow;
     // take apart small sparse group, due it will not gain on speed
     if (nnz_rate <= param.sparse_threshold) {
       for (auto fid : group) {
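The nnz_rate above is a density heuristic: the group's non-zero count divided by the row count estimates how many bin entries each row contributes, and groups at or below sparse_threshold are broken apart because grouping them gains no speed. A sketch of the check with made-up numbers (sparse_threshold is a training parameter; 0.2 is only an example value):

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const std::size_t nrow = 1000;
  // Non-zero counts of the features in one candidate group.
  const std::vector<std::size_t> feature_nnz = {120, 40, 15};

  std::size_t nnz = 0;
  for (std::size_t v : feature_nnz) nnz += v;

  const float sparse_threshold = 0.2f;  // example value only
  const float nnz_rate = static_cast<float>(nnz) / nrow;  // 0.175 here
  if (nnz_rate <= sparse_threshold) {
    std::printf("split group apart, nnz_rate = %.3f\n", nnz_rate);
  }
  return 0;
}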
@@ -496,177 +496,145 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
   }
 }
 
-void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
-                             const RowSetCollection::Elem row_indices,
-                             const GHistIndexMatrix& gmat,
-                             GHistRow hist) {
-  const size_t nthread = static_cast<size_t>(this->nthread_);
-  data_.resize(nbins_ * nthread_);
-
-  const size_t* rid = row_indices.begin;
-  const size_t nrows = row_indices.Size();
-  const uint32_t* index = gmat.index.data();
-  const size_t* row_ptr = gmat.row_ptr.data();
-  const float* pgh = reinterpret_cast<const float*>(gpair.data());
-
-  double* hist_data = reinterpret_cast<double*>(hist.data());
-  double* data = reinterpret_cast<double*>(data_.data());
-
-  const size_t block_size = 512;
-  size_t n_blocks = nrows/block_size;
-  n_blocks += !!(nrows - n_blocks*block_size);
-
-  const size_t nthread_to_process = std::min(nthread, n_blocks);
-  memset(thread_init_.data(), '\0', nthread_to_process*sizeof(size_t));
-
-  const size_t cache_line_size = 64;
-  const size_t prefetch_offset = 10;
-  size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid);
-  no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size;
-
-#pragma omp parallel for num_threads(nthread_to_process) schedule(guided)
-  for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
-    dmlc::omp_uint tid = omp_get_thread_num();
-    double* data_local_hist = ((nthread_to_process == 1) ? hist_data :
-        reinterpret_cast<double*>(data_.data() + tid * nbins_));
-
-    if (!thread_init_[tid]) {
-      memset(data_local_hist, '\0', 2*nbins_*sizeof(double));
-      thread_init_[tid] = true;
-    }
-
-    const size_t istart = iblock*block_size;
-    const size_t iend = (((iblock+1)*block_size > nrows) ? nrows : istart + block_size);
-    for (size_t i = istart; i < iend; ++i) {
-      const size_t icol_start = row_ptr[rid[i]];
-      const size_t icol_end = row_ptr[rid[i]+1];
-
-      if (i < nrows - no_prefetch_size) {
-        PREFETCH_READ_T0(row_ptr + rid[i + prefetch_offset]);
-        PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
-      }
-
-      for (size_t j = icol_start; j < icol_end; ++j) {
-        const uint32_t idx_bin = 2*index[j];
-        const size_t idx_gh = 2*rid[i];
-        data_local_hist[idx_bin] += pgh[idx_gh];
-        data_local_hist[idx_bin+1] += pgh[idx_gh+1];
-      }
-    }
-  }
-
-  if (nthread_to_process > 1) {
-    const size_t size = (2*nbins_);
-    const size_t block_size = 1024;
-    size_t n_blocks = size/block_size;
-    n_blocks += !!(size - n_blocks*block_size);
-
-    size_t n_worked_bins = 0;
-    for (size_t i = 0; i < nthread_to_process; ++i) {
-      if (thread_init_[i]) {
-        thread_init_[n_worked_bins++] = i;
-      }
-    }
-
-#pragma omp parallel for num_threads(std::min(nthread, n_blocks)) schedule(guided)
-    for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
-      const size_t istart = iblock * block_size;
-      const size_t iend = (((iblock + 1) * block_size > size) ? size : istart + block_size);
-
-      const size_t bin = 2 * thread_init_[0] * nbins_;
-      memcpy(hist_data + istart, (data + bin + istart), sizeof(double) * (iend - istart));
-
-      for (size_t i_bin_part = 1; i_bin_part < n_worked_bins; ++i_bin_part) {
-        const size_t bin = 2 * thread_init_[i_bin_part] * nbins_;
-        for (size_t i = istart; i < iend; i++) {
-          hist_data[i] += data[bin + i];
-        }
-      }
-    }
-  }
-}
-
-void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
-                                  const RowSetCollection::Elem row_indices,
-                                  const GHistIndexBlockMatrix& gmatb,
-                                  GHistRow hist) {
-  constexpr int kUnroll = 8;  // loop unrolling factor
-  const size_t nblock = gmatb.GetNumBlock();
-  const size_t nrows = row_indices.end - row_indices.begin;
-  const size_t rest = nrows % kUnroll;
-
-#if defined(_OPENMP)
-  const auto nthread = static_cast<bst_omp_uint>(this->nthread_);  // NOLINT
-#endif  // defined(_OPENMP)
-  tree::GradStats* p_hist = hist.data();
-
-#pragma omp parallel for num_threads(nthread) schedule(guided)
-  for (bst_omp_uint bid = 0; bid < nblock; ++bid) {
-    auto gmat = gmatb[bid];
-
-    for (size_t i = 0; i < nrows - rest; i += kUnroll) {
-      size_t rid[kUnroll];
-      size_t ibegin[kUnroll];
-      size_t iend[kUnroll];
-      GradientPair stat[kUnroll];
-
-      for (int k = 0; k < kUnroll; ++k) {
-        rid[k] = row_indices.begin[i + k];
-        ibegin[k] = gmat.row_ptr[rid[k]];
-        iend[k] = gmat.row_ptr[rid[k] + 1];
-        stat[k] = gpair[rid[k]];
-      }
-      for (int k = 0; k < kUnroll; ++k) {
-        for (size_t j = ibegin[k]; j < iend[k]; ++j) {
-          const uint32_t bin = gmat.index[j];
-          p_hist[bin].Add(stat[k]);
-        }
-      }
-    }
-    for (size_t i = nrows - rest; i < nrows; ++i) {
-      const size_t rid = row_indices.begin[i];
-      const size_t ibegin = gmat.row_ptr[rid];
-      const size_t iend = gmat.row_ptr[rid + 1];
-      const GradientPair stat = gpair[rid];
-      for (size_t j = ibegin; j < iend; ++j) {
-        const uint32_t bin = gmat.index[j];
-        p_hist[bin].Add(stat);
-      }
-    }
-  }
-}
-
-void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
-  const uint32_t nbins = static_cast<bst_omp_uint>(nbins_);
-  constexpr int kUnroll = 8;  // loop unrolling factor
-  const uint32_t rest = nbins % kUnroll;
-
-#if defined(_OPENMP)
-  const auto nthread = static_cast<bst_omp_uint>(this->nthread_);  // NOLINT
-#endif  // defined(_OPENMP)
-  tree::GradStats* p_self = self.data();
-  tree::GradStats* p_sibling = sibling.data();
-  tree::GradStats* p_parent = parent.data();
-
-#pragma omp parallel for num_threads(nthread) schedule(static)
-  for (bst_omp_uint bin_id = 0;
-       bin_id < static_cast<bst_omp_uint>(nbins - rest); bin_id += kUnroll) {
-    tree::GradStats pb[kUnroll];
-    tree::GradStats sb[kUnroll];
-    for (int k = 0; k < kUnroll; ++k) {
-      pb[k] = p_parent[bin_id + k];
-    }
-    for (int k = 0; k < kUnroll; ++k) {
-      sb[k] = p_sibling[bin_id + k];
-    }
-    for (int k = 0; k < kUnroll; ++k) {
-      p_self[bin_id + k].SetSubstract(pb[k], sb[k]);
-    }
-  }
-  for (uint32_t bin_id = nbins - rest; bin_id < nbins; ++bin_id) {
-    p_self[bin_id].SetSubstract(p_parent[bin_id], p_sibling[bin_id]);
-  }
-}
+// used when data layout is kDenseDataZeroBased or kDenseDataOneBased
+// it means that "row_ptr" is not needed for hist computations
+void BuildHistLocalDense(size_t istart, size_t iend, size_t nrows, const size_t* rid,
+    const uint32_t* index, const GradientPair::ValueT* pgh, const size_t* row_ptr,
+    GradStatHist::GradType* data_local_hist, GradStatHist* grad_stat_global) {
+  GradStatHist grad_stat;  // make local var to prevent false sharing
+
+  const size_t n_features = row_ptr[rid[istart]+1] - row_ptr[rid[istart]];
+  const size_t cache_line_size = 64;
+  const size_t prefetch_step = cache_line_size / sizeof(*index);
+  const size_t prefetch_offset = 10;
+
+  size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid);
+  no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size;
+
+  // if read each row in some block of bin-matrix - it's dense block
+  // and we dont need SW prefetch in this case
+  const bool denseBlock = (rid[iend-1] - rid[istart]) == (iend - istart - 1);
+
+  if (iend < nrows - no_prefetch_size && !denseBlock) {
+    for (size_t i = istart; i < iend; ++i) {
+      const size_t icol_start = rid[i] * n_features;
+      const size_t icol_start_prefetch = rid[i+prefetch_offset] * n_features;
+      const size_t idx_gh = 2*rid[i];
+
+      PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
+
+      for (size_t j = icol_start_prefetch; j < icol_start_prefetch + n_features;
+          j += prefetch_step) {
+        PREFETCH_READ_T0(index + j);
+      }
+
+      grad_stat.sum_grad += pgh[idx_gh];
+      grad_stat.sum_hess += pgh[idx_gh+1];
+
+      for (size_t j = icol_start; j < icol_start + n_features; ++j) {
+        const uint32_t idx_bin = 2*index[j];
+        data_local_hist[idx_bin] += pgh[idx_gh];
+        data_local_hist[idx_bin+1] += pgh[idx_gh+1];
+      }
+    }
+  } else {
+    for (size_t i = istart; i < iend; ++i) {
+      const size_t icol_start = rid[i] * n_features;
+      const size_t idx_gh = 2*rid[i];
+      grad_stat.sum_grad += pgh[idx_gh];
+      grad_stat.sum_hess += pgh[idx_gh+1];
+
+      for (size_t j = icol_start; j < icol_start + n_features; ++j) {
+        const uint32_t idx_bin = 2*index[j];
+        data_local_hist[idx_bin] += pgh[idx_gh];
+        data_local_hist[idx_bin+1] += pgh[idx_gh+1];
+      }
+    }
+  }
+  grad_stat_global->Add(grad_stat);
+}
+
+// used when data layout is kSparseData
+// it means that "row_ptr" is needed for hist computations
+void BuildHistLocalSparse(size_t istart, size_t iend, size_t nrows, const size_t* rid,
+    const uint32_t* index, const GradientPair::ValueT* pgh, const size_t* row_ptr,
+    GradStatHist::GradType* data_local_hist, GradStatHist* grad_stat_global) {
+  GradStatHist grad_stat;  // make local var to prevent false sharing
+
+  const size_t cache_line_size = 64;
+  const size_t prefetch_step = cache_line_size / sizeof(index[0]);
+  const size_t prefetch_offset = 10;
+
+  size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid);
+  no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size;
+
+  // if read each row in some block of bin-matrix - it's dense block
+  // and we dont need SW prefetch in this case
+  const bool denseBlock = (rid[iend-1] - rid[istart]) == (iend - istart);
+
+  if (iend < nrows - no_prefetch_size && !denseBlock) {
+    for (size_t i = istart; i < iend; ++i) {
+      const size_t icol_start = row_ptr[rid[i]];
+      const size_t icol_end = row_ptr[rid[i]+1];
+      const size_t idx_gh = 2*rid[i];
+
+      const size_t icol_start10 = row_ptr[rid[i+prefetch_offset]];
+      const size_t icol_end10 = row_ptr[rid[i+prefetch_offset]+1];
+
+      PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
+
+      for (size_t j = icol_start10; j < icol_end10; j+=prefetch_step) {
+        PREFETCH_READ_T0(index + j);
+      }
+
+      grad_stat.sum_grad += pgh[idx_gh];
+      grad_stat.sum_hess += pgh[idx_gh+1];
+
+      for (size_t j = icol_start; j < icol_end; ++j) {
+        const uint32_t idx_bin = 2*index[j];
+        data_local_hist[idx_bin] += pgh[idx_gh];
+        data_local_hist[idx_bin+1] += pgh[idx_gh+1];
+      }
+    }
+  } else {
+    for (size_t i = istart; i < iend; ++i) {
+      const size_t icol_start = row_ptr[rid[i]];
+      const size_t icol_end = row_ptr[rid[i]+1];
+      const size_t idx_gh = 2*rid[i];
+
+      grad_stat.sum_grad += pgh[idx_gh];
+      grad_stat.sum_hess += pgh[idx_gh+1];
+
+      for (size_t j = icol_start; j < icol_end; ++j) {
+        const uint32_t idx_bin = 2*index[j];
+        data_local_hist[idx_bin] += pgh[idx_gh];
+        data_local_hist[idx_bin+1] += pgh[idx_gh+1];
+      }
+    }
+  }
+  grad_stat_global->Add(grad_stat);
+}
+
+void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
+  GradStatHist* p_self = self.data();
+  GradStatHist* p_sibling = sibling.data();
+  GradStatHist* p_parent = parent.data();
+
+  const size_t size = self.size();
+  CHECK_EQ(sibling.size(), size);
+  CHECK_EQ(parent.size(), size);
+
+  const size_t block_size = 1024;  // aproximatly 1024 values per block
+  size_t n_blocks = size/block_size + !!(size%block_size);
+
+#pragma omp parallel for
+  for (int iblock = 0; iblock < n_blocks; ++iblock) {
+    const size_t ibegin = iblock*block_size;
+    const size_t iend = (((iblock+1)*block_size > size) ? size : ibegin + block_size);
+    for (bst_omp_uint bin_id = ibegin; bin_id < iend; bin_id++) {
+      p_self[bin_id].SetSubstract(p_parent[bin_id], p_sibling[bin_id]);
+    }
+  }
+}
 
 }  // namespace common
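The new free-standing SubtractionTrick relies on the invariant that every row of a parent node lands in exactly one of its two children, so per bin parent = left + right, and the second child's histogram can be computed as parent minus sibling instead of a second pass over the data. A minimal sketch of that identity with a simplified stat type (this GradStat and its SetSubstract only mimic, and are not, the real GradStatHist):

#include <cstddef>
#include <cstdio>
#include <vector>

struct GradStat {
  double sum_grad = 0.0;
  double sum_hess = 0.0;
  void SetSubstract(const GradStat& p, const GradStat& s) {
    sum_grad = p.sum_grad - s.sum_grad;  // parent minus sibling
    sum_hess = p.sum_hess - s.sum_hess;
  }
};

int main() {
  // Per-bin stats: parent rows are the union of the two children's rows,
  // so parent[bin] == left[bin] + right[bin], hence right = parent - left.
  std::vector<GradStat> parent = {{1.0, 2.0}, {3.0, 4.0}};
  std::vector<GradStat> left = {{0.25, 0.5}, {1.0, 1.5}};
  std::vector<GradStat> right(parent.size());

  for (std::size_t i = 0; i < parent.size(); ++i) {
    right[i].SetSubstract(parent[i], left[i]);
  }
  std::printf("right[1] = (%.2f, %.2f)\n", right[1].sum_grad, right[1].sum_hess);
  return 0;
}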
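BuildHistLocalDense and BuildHistLocalSparse above both look ahead by prefetch_offset rows and issue one prefetch per cache line of the upcoming row's bin data, so the gathers in later iterations hit cache. A standalone sketch of that pattern for a dense bin matrix; PREFETCH_READ_T0 is assumed to wrap __builtin_prefetch or _mm_prefetch (the GCC builtin is used here), and all sizes are illustrative:

#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

#if defined(__GNUC__)
#define PREFETCH_READ_T0(addr) __builtin_prefetch((addr), 0, 3)
#else
#define PREFETCH_READ_T0(addr) do {} while (0)
#endif

// Sum the bin codes of the listed rows of a dense bin matrix,
// prefetching the rows that will be read a few iterations from now.
uint64_t SumBins(const std::vector<std::size_t>& rid,
                 const std::vector<uint32_t>& index,
                 std::size_t n_features) {
  const std::size_t prefetch_offset = 10;  // look-ahead distance, as in the diff
  const std::size_t prefetch_step = 64 / sizeof(uint32_t);  // one cache line of bins
  uint64_t sum = 0;
  for (std::size_t i = 0; i < rid.size(); ++i) {
    if (i + prefetch_offset < rid.size()) {
      const std::size_t next = rid[i + prefetch_offset] * n_features;
      for (std::size_t j = next; j < next + n_features; j += prefetch_step) {
        PREFETCH_READ_T0(index.data() + j);
      }
    }
    const std::size_t start = rid[i] * n_features;
    for (std::size_t j = start; j < start + n_features; ++j) {
      sum += index[j];
    }
  }
  return sum;
}

int main() {
  const std::size_t n_features = 32;
  std::vector<uint32_t> index(100 * n_features);
  std::iota(index.begin(), index.end(), 0u);
  std::vector<std::size_t> rid(100);
  std::iota(rid.begin(), rid.end(), std::size_t{0});
  return SumBins(rid, index, n_features) > 0 ? 0 : 1;
}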