From d20ceac77c157f16579087dd007ead35b1fb261b Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Wed, 26 Feb 2020 13:03:57 +0800 Subject: [PATCH] decouple bagging with num_threads (#2804) * fix bagging * fixed cpplint issues * updated docs Co-authored-by: Nikita Titov --- docs/FAQ.rst | 8 ++ docs/Parameters.rst | 2 +- include/LightGBM/config.h | 2 +- include/LightGBM/utils/threading.h | 137 ++++++++++++++++++++++++++- src/boosting/gbdt.cpp | 144 +++++++++++------------------ src/boosting/gbdt.h | 40 ++------ src/boosting/goss.hpp | 30 +++--- src/treelearner/data_partition.hpp | 95 +++---------------- 8 files changed, 226 insertions(+), 232 deletions(-) diff --git a/docs/FAQ.rst b/docs/FAQ.rst index ce88bf042acd..d8ec3f592bcc 100644 --- a/docs/FAQ.rst +++ b/docs/FAQ.rst @@ -1,3 +1,6 @@ +.. role:: raw-html(raw) + :format: html + LightGBM FAQ ############ @@ -82,8 +85,13 @@ You may also use the CPU version. 6. Bagging is not reproducible when changing the number of threads. ------------------------------------------------------------------- +:raw-html:`` LightGBM bagging is multithreaded, so its output depends on the number of threads used. There is `no workaround currently `__. +:raw-html:`` + +Starting from `#2804 `__ bagging result doesn't depend on the number of threads. +So this issue should be solved in the latest version. 7. I tried to use Random Forest mode, and LightGBM crashes! ----------------------------------------------------------- diff --git a/docs/Parameters.rst b/docs/Parameters.rst index 9115d750f2f7..d62f1b90c3ab 100644 --- a/docs/Parameters.rst +++ b/docs/Parameters.rst @@ -566,7 +566,7 @@ Dataset Parameters - ``data_random_seed`` :raw-html:`🔗︎`, default = ``1``, type = int, aliases: ``data_seed`` - - random seed for data partition in parallel learning (excluding the ``feature_parallel`` mode) + - random seed for sampling data to construct histogram bins - ``is_enable_sparse`` :raw-html:`🔗︎`, default = ``true``, type = bool, aliases: ``is_sparse``, ``enable_sparse``, ``sparse`` diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index 900e9ac27b74..a425c995f780 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -529,7 +529,7 @@ struct Config { int bin_construct_sample_cnt = 200000; // alias = data_seed - // desc = random seed for data partition in parallel learning (excluding the ``feature_parallel`` mode) + // desc = random seed for sampling data to construct histogram bins int data_random_seed = 1; // alias = is_sparse, enable_sparse, sparse diff --git a/include/LightGBM/utils/threading.h b/include/LightGBM/utils/threading.h index 909ab74d8cd5..c4cb57a49d73 100644 --- a/include/LightGBM/utils/threading.h +++ b/include/LightGBM/utils/threading.h @@ -1,12 +1,16 @@ /*! * Copyright (c) 2016 Microsoft Corporation. All rights reserved. - * Licensed under the MIT License. See LICENSE file in the project root for license information. + * Licensed under the MIT License. See LICENSE file in the project root for + * license information. */ #ifndef LIGHTGBM_UTILS_THREADING_H_ #define LIGHTGBM_UTILS_THREADING_H_ +#include +#include #include +#include #include #include @@ -37,6 +41,23 @@ class Threading { *block_size = cnt; } } + template + static inline void BlockInfoForceSize(int num_threads, INDEX_T cnt, + INDEX_T min_cnt_per_block, + int* out_nblock, INDEX_T* block_size) { + *out_nblock = std::min( + num_threads, + static_cast((cnt + min_cnt_per_block - 1) / min_cnt_per_block)); + if (*out_nblock > 1) { + *block_size = (cnt + (*out_nblock) - 1) / (*out_nblock); + // force the block size to the times of min_cnt_per_block + *block_size = (*block_size + min_cnt_per_block - 1) / min_cnt_per_block * + min_cnt_per_block; + } else { + *block_size = cnt; + } + } + template static inline int For( INDEX_T start, INDEX_T end, INDEX_T min_block_size, @@ -58,6 +79,116 @@ class Threading { } }; -} // namespace LightGBM +template +class ParallelPartitionRunner { + public: + ParallelPartitionRunner(INDEX_T num_data, INDEX_T min_block_size) + : min_block_size_(min_block_size) { + num_threads_ = 1; +#pragma omp parallel +#pragma omp master + { num_threads_ = omp_get_num_threads(); } + left_.resize(num_data); + if (TWO_BUFFER) { + right_.resize(num_data); + } + offsets_.resize(num_threads_); + left_cnts_.resize(num_threads_); + right_cnts_.resize(num_threads_); + left_write_pos_.resize(num_threads_); + right_write_pos_.resize(num_threads_); + } + + ~ParallelPartitionRunner() {} + + void ReSize(INDEX_T num_data) { + left_.resize(num_data); + if (TWO_BUFFER) { + right_.resize(num_data); + } + } + + template + INDEX_T Run( + INDEX_T cnt, + const std::function& func, + INDEX_T* out) { + int nblock = 1; + INDEX_T inner_size = cnt; + if (FORCE_SIZE) { + Threading::BlockInfoForceSize(num_threads_, cnt, min_block_size_, + &nblock, &inner_size); + } else { + Threading::BlockInfo(num_threads_, cnt, min_block_size_, &nblock, + &inner_size); + } + + OMP_INIT_EX(); +#pragma omp parallel for schedule(static, 1) + for (int i = 0; i < nblock; ++i) { + OMP_LOOP_EX_BEGIN(); + INDEX_T cur_start = i * inner_size; + INDEX_T cur_cnt = std::min(inner_size, cnt - cur_start); + offsets_[i] = cur_start; + if (cur_cnt <= 0) { + left_cnts_[i] = 0; + right_cnts_[i] = 0; + continue; + } + auto left_ptr = left_.data() + cur_start; + INDEX_T* right_ptr = nullptr; + if (TWO_BUFFER) { + right_ptr = right_.data() + cur_start; + } + // split data inner, reduce the times of function called + INDEX_T cur_left_count = + func(i, cur_start, cur_cnt, left_ptr, right_ptr); + if (!TWO_BUFFER) { + // reverse for one buffer + std::reverse(left_ptr + cur_left_count, left_ptr + cur_cnt); + } + left_cnts_[i] = cur_left_count; + right_cnts_[i] = cur_cnt - cur_left_count; + OMP_LOOP_EX_END(); + } + OMP_THROW_EX(); + + left_write_pos_[0] = 0; + right_write_pos_[0] = 0; + for (int i = 1; i < nblock; ++i) { + left_write_pos_[i] = left_write_pos_[i - 1] + left_cnts_[i - 1]; + right_write_pos_[i] = right_write_pos_[i - 1] + right_cnts_[i - 1]; + } + data_size_t left_cnt = left_write_pos_[nblock - 1] + left_cnts_[nblock - 1]; + + auto right_start = out + left_cnt; +#pragma omp parallel for schedule(static) + for (int i = 0; i < nblock; ++i) { + std::copy_n(left_.data() + offsets_[i], left_cnts_[i], + out + left_write_pos_[i]); + if (TWO_BUFFER) { + std::copy_n(right_.data() + offsets_[i], right_cnts_[i], + right_start + right_write_pos_[i]); + } else { + std::copy_n(left_.data() + offsets_[i] + left_cnts_[i], right_cnts_[i], + right_start + right_write_pos_[i]); + } + } + return left_cnt; + } + + private: + int num_threads_; + INDEX_T min_block_size_; + std::vector left_; + std::vector right_; + std::vector offsets_; + std::vector left_cnts_; + std::vector right_cnts_; + std::vector left_write_pos_; + std::vector right_write_pos_; +}; + +} // namespace LightGBM -#endif // LightGBM_UTILS_THREADING_H_ +#endif // LightGBM_UTILS_THREADING_H_ diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index f1072cac7117..039465f78c90 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -18,24 +18,21 @@ namespace LightGBM { -GBDT::GBDT() : iter_(0), -train_data_(nullptr), -objective_function_(nullptr), -early_stopping_round_(0), -es_first_metric_only_(false), -max_feature_idx_(0), -num_tree_per_iteration_(1), -num_class_(1), -num_iteration_for_pred_(0), -shrinkage_rate_(0.1f), -num_init_iteration_(0), -need_re_bagging_(false), -balanced_bagging_(false) { - #pragma omp parallel - #pragma omp master - { - num_threads_ = omp_get_num_threads(); - } +GBDT::GBDT() + : iter_(0), + train_data_(nullptr), + objective_function_(nullptr), + early_stopping_round_(0), + es_first_metric_only_(false), + max_feature_idx_(0), + num_tree_per_iteration_(1), + num_class_(1), + num_iteration_for_pred_(0), + shrinkage_rate_(0.1f), + num_init_iteration_(0), + need_re_bagging_(false), + balanced_bagging_(false), + bagging_runner_(0, bagging_rand_block_) { average_output_ = false; tree_learner_ = nullptr; } @@ -164,53 +161,50 @@ void GBDT::Boosting() { GetGradients(GetTrainingScore(&num_score), gradients_.data(), hessians_.data()); } -data_size_t GBDT::BaggingHelper(Random* cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer) { +data_size_t GBDT::BaggingHelper(data_size_t start, data_size_t cnt, data_size_t* buffer) { if (cnt <= 0) { return 0; } - data_size_t bag_data_cnt = static_cast(config_->bagging_fraction * cnt); data_size_t cur_left_cnt = 0; - data_size_t cur_right_cnt = 0; - auto right_buffer = buffer + bag_data_cnt; + data_size_t cur_right_pos = cnt; // random bagging, minimal unit is one record for (data_size_t i = 0; i < cnt; ++i) { - float prob = (bag_data_cnt - cur_left_cnt) / static_cast(cnt - i); - if (cur_rand->NextFloat() < prob) { - buffer[cur_left_cnt++] = start + i; + auto cur_idx = start + i; + if (bagging_rands_[cur_idx / bagging_rand_block_].NextFloat() < config_->bagging_fraction) { + buffer[cur_left_cnt++] = cur_idx; } else { - right_buffer[cur_right_cnt++] = start + i; + buffer[--cur_right_pos] = cur_idx; } } - CHECK(cur_left_cnt == bag_data_cnt); return cur_left_cnt; } -data_size_t GBDT::BalancedBaggingHelper(Random* cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer) { +data_size_t GBDT::BalancedBaggingHelper(data_size_t start, data_size_t cnt, + data_size_t* buffer) { if (cnt <= 0) { return 0; } auto label_ptr = train_data_->metadata().label(); data_size_t cur_left_cnt = 0; - data_size_t cur_right_pos = cnt - 1; - // from right to left - auto right_buffer = buffer; + data_size_t cur_right_pos = cnt; // random bagging, minimal unit is one record for (data_size_t i = 0; i < cnt; ++i) { + auto cur_idx = start + i; bool is_pos = label_ptr[start + i] > 0; bool is_in_bag = false; if (is_pos) { - is_in_bag = cur_rand->NextFloat() < config_->pos_bagging_fraction; + is_in_bag = bagging_rands_[cur_idx / bagging_rand_block_].NextFloat() < + config_->pos_bagging_fraction; } else { - is_in_bag = cur_rand->NextFloat() < config_->neg_bagging_fraction; + is_in_bag = bagging_rands_[cur_idx / bagging_rand_block_].NextFloat() < + config_->neg_bagging_fraction; } if (is_in_bag) { - buffer[cur_left_cnt++] = start + i; + buffer[cur_left_cnt++] = cur_idx; } else { - right_buffer[cur_right_pos--] = start + i; + buffer[--cur_right_pos] = cur_idx; } } - // reverse right buffer - std::reverse(buffer + cur_left_cnt, buffer + cnt); return cur_left_cnt; } @@ -220,54 +214,20 @@ void GBDT::Bagging(int iter) { if ((bag_data_cnt_ < num_data_ && iter % config_->bagging_freq == 0) || need_re_bagging_) { need_re_bagging_ = false; - int n_block = Threading::For( - 0, num_data_, 1024, - [this, iter](int i, data_size_t cur_start, data_size_t cur_end) { - data_size_t cur_cnt = cur_end - cur_start; - if (cur_cnt <= 0) { - left_cnts_buf_[i] = 0; - right_cnts_buf_[i] = 0; + auto left_cnt = bagging_runner_.Run( + num_data_, + [=](int, data_size_t cur_start, data_size_t cur_cnt, data_size_t* left, + data_size_t*) { + data_size_t cur_left_count = 0; + if (balanced_bagging_) { + cur_left_count = + BalancedBaggingHelper(cur_start, cur_cnt, left); } else { - Random cur_rand(config_->bagging_seed + iter * num_threads_ + i); - data_size_t cur_left_count = 0; - if (balanced_bagging_) { - cur_left_count = - BalancedBaggingHelper(&cur_rand, cur_start, cur_cnt, - tmp_indices_.data() + cur_start); - } else { - cur_left_count = BaggingHelper(&cur_rand, cur_start, cur_cnt, - tmp_indices_.data() + cur_start); - } - offsets_buf_[i] = cur_start; - left_cnts_buf_[i] = cur_left_count; - right_cnts_buf_[i] = cur_cnt - cur_left_count; + cur_left_count = BaggingHelper(cur_start, cur_cnt, left); } - }); - data_size_t left_cnt = 0; - left_write_pos_buf_[0] = 0; - right_write_pos_buf_[0] = 0; - for (int i = 1; i < n_block; ++i) { - left_write_pos_buf_[i] = - left_write_pos_buf_[i - 1] + left_cnts_buf_[i - 1]; - right_write_pos_buf_[i] = - right_write_pos_buf_[i - 1] + right_cnts_buf_[i - 1]; - } - left_cnt = left_write_pos_buf_[n_block - 1] + left_cnts_buf_[n_block - 1]; - -#pragma omp parallel for schedule(static, 1) - for (int i = 0; i < n_block; ++i) { - if (left_cnts_buf_[i] > 0) { - std::memcpy(bag_data_indices_.data() + left_write_pos_buf_[i], - tmp_indices_.data() + offsets_buf_[i], - left_cnts_buf_[i] * sizeof(data_size_t)); - } - if (right_cnts_buf_[i] > 0) { - std::memcpy( - bag_data_indices_.data() + left_cnt + right_write_pos_buf_[i], - tmp_indices_.data() + offsets_buf_[i] + left_cnts_buf_[i], - right_cnts_buf_[i] * sizeof(data_size_t)); - } - } + return cur_left_count; + }, + bag_data_indices_.data()); bag_data_cnt_ = left_cnt; Log::Debug("Re-bagging, using %d data to train", bag_data_cnt_); // set bagging data to tree learner @@ -780,15 +740,15 @@ void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) { bag_data_cnt_ = static_cast(config->bagging_fraction * num_data_); } bag_data_indices_.resize(num_data_); - tmp_indices_.resize(num_data_); - - offsets_buf_.resize(num_threads_); - left_cnts_buf_.resize(num_threads_); - right_cnts_buf_.resize(num_threads_); - left_write_pos_buf_.resize(num_threads_); - right_write_pos_buf_.resize(num_threads_); + bagging_runner_.ReSize(num_data_); + bagging_rands_.clear(); + for (int i = 0; + i < (num_data_ + bagging_rand_block_ - 1) / bagging_rand_block_; ++i) { + bagging_rands_.emplace_back(config_->bagging_seed + i); + } - double average_bag_rate = (bag_data_cnt_ / num_data_) / config->bagging_freq; + double average_bag_rate = + (static_cast(bag_data_cnt_) / num_data_) / config->bagging_freq; is_use_subset_ = false; const int group_threshold_usesubset = 100; if (tree_learner_->IsHistColWise() && average_bag_rate <= 0.5 @@ -813,7 +773,7 @@ void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) { } else { bag_data_cnt_ = num_data_; bag_data_indices_.clear(); - tmp_indices_.clear(); + bagging_runner_.ReSize(0); is_use_subset_ = false; } } diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index ec78485dac91..c4e452a69d19 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -390,24 +391,11 @@ class GBDT : public GBDTBase { */ virtual void Bagging(int iter); - /*! - * \brief Helper function for bagging, used for multi-threading optimization - * \param start start indice of bagging - * \param cnt count - * \param buffer output buffer - * \return count of left size - */ - data_size_t BaggingHelper(Random* cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer); + virtual data_size_t BaggingHelper(data_size_t start, data_size_t cnt, + data_size_t* buffer); - - /*! - * \brief Helper function for bagging, used for multi-threading optimization, balanced sampling - * \param start start indice of bagging - * \param cnt count - * \param buffer output buffer - * \return count of left size - */ - data_size_t BalancedBaggingHelper(Random* cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer); + data_size_t BalancedBaggingHelper(data_size_t start, data_size_t cnt, + data_size_t* buffer); /*! * \brief calculate the object function @@ -476,8 +464,6 @@ class GBDT : public GBDTBase { std::vector> bag_data_indices_; /*! \brief Number of in-bag data */ data_size_t bag_data_cnt_; - /*! \brief Store the indices of in-bag data */ - std::vector tmp_indices_; /*! \brief Number of training data */ data_size_t num_data_; /*! \brief Number of trees per iterations */ @@ -495,18 +481,6 @@ class GBDT : public GBDTBase { /*! \brief Feature names */ std::vector feature_names_; std::vector feature_infos_; - /*! \brief number of threads */ - int num_threads_; - /*! \brief Buffer for multi-threading bagging */ - std::vector offsets_buf_; - /*! \brief Buffer for multi-threading bagging */ - std::vector left_cnts_buf_; - /*! \brief Buffer for multi-threading bagging */ - std::vector right_cnts_buf_; - /*! \brief Buffer for multi-threading bagging */ - std::vector left_write_pos_buf_; - /*! \brief Buffer for multi-threading bagging */ - std::vector right_write_pos_buf_; std::unique_ptr tmp_subset_; bool is_use_subset_; std::vector class_need_train_; @@ -517,7 +491,9 @@ class GBDT : public GBDTBase { bool balanced_bagging_; std::string loaded_parameter_; std::vector monotone_constraints_; - + const int bagging_rand_block_ = 1024; + std::vector bagging_rands_; + ParallelPartitionRunner bagging_runner_; Json forced_splits_json_; }; diff --git a/src/boosting/goss.hpp b/src/boosting/goss.hpp index 1d4b07b21977..75b602f20c96 100644 --- a/src/boosting/goss.hpp +++ b/src/boosting/goss.hpp @@ -57,15 +57,9 @@ class GOSS: public GBDT { Log::Fatal("Cannot use bagging in GOSS"); } Log::Info("Using GOSS"); - + balanced_bagging_ = false; bag_data_indices_.resize(num_data_); - tmp_indices_.resize(num_data_); - tmp_indice_right_.resize(num_data_); - offsets_buf_.resize(num_threads_); - left_cnts_buf_.resize(num_threads_); - right_cnts_buf_.resize(num_threads_); - left_write_pos_buf_.resize(num_threads_); - right_write_pos_buf_.resize(num_threads_); + bagging_runner_.ReSize(num_data_); is_use_subset_ = false; if (config_->top_rate + config_->other_rate <= 0.5) { @@ -79,7 +73,7 @@ class GOSS: public GBDT { bag_data_cnt_ = num_data_; } - data_size_t BaggingHelper(Random* cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer, data_size_t* buffer_right) { + data_size_t BaggingHelper(data_size_t start, data_size_t cnt, data_size_t* buffer) override { if (cnt <= 0) { return 0; } @@ -98,31 +92,32 @@ class GOSS: public GBDT { score_t multiply = static_cast(cnt - top_k) / other_k; data_size_t cur_left_cnt = 0; - data_size_t cur_right_cnt = 0; + data_size_t cur_right_pos = cnt; data_size_t big_weight_cnt = 0; for (data_size_t i = 0; i < cnt; ++i) { + auto cur_idx = start + i; score_t grad = 0.0f; for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) { - size_t idx = static_cast(cur_tree_id) * num_data_ + start + i; + size_t idx = static_cast(cur_tree_id) * num_data_ + cur_idx; grad += std::fabs(gradients_[idx] * hessians_[idx]); } if (grad >= threshold) { - buffer[cur_left_cnt++] = start + i; + buffer[cur_left_cnt++] = cur_idx; ++big_weight_cnt; } else { data_size_t sampled = cur_left_cnt - big_weight_cnt; data_size_t rest_need = other_k - sampled; data_size_t rest_all = (cnt - i) - (top_k - big_weight_cnt); double prob = (rest_need) / static_cast(rest_all); - if (cur_rand->NextFloat() < prob) { - buffer[cur_left_cnt++] = start + i; + if (bagging_rands_[cur_idx / bagging_rand_block_].NextFloat() < prob) { + buffer[cur_left_cnt++] = cur_idx; for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) { - size_t idx = static_cast(cur_tree_id) * num_data_ + start + i; + size_t idx = static_cast(cur_tree_id) * num_data_ + cur_idx; gradients_[idx] *= multiply; hessians_[idx] *= multiply; } } else { - buffer_right[cur_right_cnt++] = start + i; + buffer[--cur_right_pos] = cur_idx; } } } @@ -135,9 +130,6 @@ class GOSS: public GBDT { if (iter < static_cast(1.0f / config_->learning_rate)) { return; } GBDT::Bagging(iter); } - - private: - std::vector tmp_indice_right_; }; } // namespace LightGBM diff --git a/src/treelearner/data_partition.hpp b/src/treelearner/data_partition.hpp index e9a6f43f146e..a386dbd0a8e9 100644 --- a/src/treelearner/data_partition.hpp +++ b/src/treelearner/data_partition.hpp @@ -21,23 +21,11 @@ namespace LightGBM { class DataPartition { public: DataPartition(data_size_t num_data, int num_leaves) - :num_data_(num_data), num_leaves_(num_leaves) { + : num_data_(num_data), num_leaves_(num_leaves), runner_(num_data, 512) { leaf_begin_.resize(num_leaves_); leaf_count_.resize(num_leaves_); indices_.resize(num_data_); - temp_left_indices_.resize(num_data_); - temp_right_indices_.resize(num_data_); used_data_indices_ = nullptr; - #pragma omp parallel - #pragma omp master - { - num_threads_ = omp_get_num_threads(); - } - offsets_buf_.resize(num_threads_); - left_cnts_buf_.resize(num_threads_); - right_cnts_buf_.resize(num_threads_); - left_write_pos_buf_.resize(num_threads_); - right_write_pos_buf_.resize(num_threads_); } void ResetLeaves(int num_leaves) { @@ -49,8 +37,7 @@ class DataPartition { void ResetNumData(int num_data) { num_data_ = num_data; indices_.resize(num_data_); - temp_left_indices_.resize(num_data_); - temp_right_indices_.resize(num_data_); + runner_.ReSize(num_data_); } ~DataPartition() { @@ -118,63 +105,18 @@ class DataPartition { // get leaf boundary const data_size_t begin = leaf_begin_[leaf]; const data_size_t cnt = leaf_count_[leaf]; - - int nblock = 1; - data_size_t inner_size = cnt; - Threading::BlockInfo(num_threads_, cnt, 512, &nblock, - &inner_size); auto left_start = indices_.data() + begin; - global_timer.Start("DataPartition::Split.MT"); - // split data multi-threading - OMP_INIT_EX(); -#pragma omp parallel for schedule(static, 1) - for (int i = 0; i < nblock; ++i) { - OMP_LOOP_EX_BEGIN(); - data_size_t cur_start = i * inner_size; - data_size_t cur_cnt = std::min(inner_size, cnt - cur_start); - if (cur_cnt <= 0) { - left_cnts_buf_[i] = 0; - right_cnts_buf_[i] = 0; - continue; - } - // split data inner, reduce the times of function called - data_size_t cur_left_count = - dataset->Split(feature, threshold, num_threshold, default_left, - left_start + cur_start, cur_cnt, - temp_left_indices_.data() + cur_start, - temp_right_indices_.data() + cur_start); - offsets_buf_[i] = cur_start; - left_cnts_buf_[i] = cur_left_count; - right_cnts_buf_[i] = cur_cnt - cur_left_count; - OMP_LOOP_EX_END(); - } - OMP_THROW_EX(); - global_timer.Stop("DataPartition::Split.MT"); - global_timer.Start("DataPartition::Split.Merge"); - left_write_pos_buf_[0] = 0; - right_write_pos_buf_[0] = 0; - for (int i = 1; i < nblock; ++i) { - left_write_pos_buf_[i] = - left_write_pos_buf_[i - 1] + left_cnts_buf_[i - 1]; - right_write_pos_buf_[i] = - right_write_pos_buf_[i - 1] + right_cnts_buf_[i - 1]; - } - data_size_t left_cnt = - left_write_pos_buf_[nblock - 1] + left_cnts_buf_[nblock - 1]; - - auto right_start = left_start + left_cnt; -#pragma omp parallel for schedule(static) - for (int i = 0; i < nblock; ++i) { - std::copy_n(temp_left_indices_.data() + offsets_buf_[i], - left_cnts_buf_[i], left_start + left_write_pos_buf_[i]); - std::copy_n(temp_right_indices_.data() + offsets_buf_[i], - right_cnts_buf_[i], right_start + right_write_pos_buf_[i]); - } - // update leaf boundary + auto left_cnt = runner_.Run( + cnt, + [=](int, data_size_t cur_start, data_size_t cur_cnt, data_size_t* left, + data_size_t* right) { + return dataset->Split(feature, threshold, num_threshold, default_left, + left_start + cur_start, cur_cnt, left, right); + }, + left_start); leaf_count_[leaf] = left_cnt; leaf_begin_[right_leaf] = left_cnt + begin; leaf_count_[right_leaf] = cnt - left_cnt; - global_timer.Stop("DataPartition::Split.Merge"); } /*! @@ -217,26 +159,11 @@ class DataPartition { std::vector leaf_count_; /*! \brief Store all data's indices, order by leaf[data_in_leaf0,..,data_leaf1,..] */ std::vector> indices_; - /*! \brief team indices buffer for split */ - std::vector> temp_left_indices_; - /*! \brief team indices buffer for split */ - std::vector> temp_right_indices_; /*! \brief used data indices, used for bagging */ const data_size_t* used_data_indices_; /*! \brief used data count, used for bagging */ data_size_t used_data_count_; - /*! \brief number of threads */ - int num_threads_; - /*! \brief Buffer for multi-threading data partition, used to store offset for different threads */ - std::vector offsets_buf_; - /*! \brief Buffer for multi-threading data partition, used to store left count after split for different threads */ - std::vector left_cnts_buf_; - /*! \brief Buffer for multi-threading data partition, used to store right count after split for different threads */ - std::vector right_cnts_buf_; - /*! \brief Buffer for multi-threading data partition, used to store write position of left leaf for different threads */ - std::vector left_write_pos_buf_; - /*! \brief Buffer for multi-threading data partition, used to store write position of right leaf for different threads */ - std::vector right_write_pos_buf_; + ParallelPartitionRunner runner_; }; } // namespace LightGBM