From 380a7a03e255cb1e6c652e080b7ee7669be21305 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 19 Jul 2023 13:41:47 +0800 Subject: [PATCH 1/2] Implement sketching with Hessian on GPU. --- include/xgboost/data.h | 12 +- include/xgboost/host_device_vector.h | 8 +- src/common/hist_util.cc | 10 +- src/common/hist_util.cu | 263 +++++++++++++------------ src/common/hist_util.cuh | 44 +++-- src/common/hist_util.h | 2 +- src/common/host_device_vector.cc | 3 + src/common/host_device_vector.cu | 5 + src/data/ellpack_page.cu | 2 +- src/data/gradient_index.cc | 2 +- src/data/gradient_index.h | 2 +- src/data/sparse_page_dmatrix.cu | 4 +- tests/cpp/common/test_hist_util.cu | 282 +++++++++++++++++++++------ tests/cpp/common/test_quantile.cu | 15 +- 14 files changed, 432 insertions(+), 222 deletions(-) diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 472ca43b3c3a..eae2f612bc45 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -185,10 +185,10 @@ class MetaInfo { return data_split_mode == DataSplitMode::kRow; } - /*! \brief Whether the data is split column-wise. */ - bool IsColumnSplit() const { - return data_split_mode == DataSplitMode::kCol; - } + /** @brief Whether the data is split column-wise. */ + bool IsColumnSplit() const { return data_split_mode == DataSplitMode::kCol; } + /** @brief Whether this is a learning to rank data. */ + bool IsRanking() const { return !group_ptr_.empty(); } /*! * \brief A convenient method to check if we are doing vertical federated learning, which requires @@ -249,7 +249,7 @@ struct BatchParam { /** * \brief Hessian, used for sketching with future approx implementation. */ - common::Span hess; + common::Span hess; /** * \brief Whether should we force DMatrix to regenerate the batch. Only used for * GHistIndex. @@ -279,7 +279,7 @@ struct BatchParam { * Get batch with sketch weighted by hessian. The batch will be regenerated if the * span is changed, so caller should keep the span for each iteration. */ - BatchParam(bst_bin_t max_bin, common::Span hessian, bool regenerate) + BatchParam(bst_bin_t max_bin, common::Span hessian, bool regenerate) : max_bin{max_bin}, hess{hessian}, regen{regenerate} {} [[nodiscard]] bool ParamNotEqual(BatchParam const& other) const { diff --git a/include/xgboost/host_device_vector.h b/include/xgboost/host_device_vector.h index b9fb151047c6..b221d72067d1 100644 --- a/include/xgboost/host_device_vector.h +++ b/include/xgboost/host_device_vector.h @@ -49,11 +49,12 @@ #ifndef XGBOOST_HOST_DEVICE_VECTOR_H_ #define XGBOOST_HOST_DEVICE_VECTOR_H_ +#include // for DeviceOrd +#include // for Span + #include -#include #include - -#include "span.h" +#include namespace xgboost { @@ -133,6 +134,7 @@ class HostDeviceVector { GPUAccess DeviceAccess() const; void SetDevice(int device) const; + void SetDevice(DeviceOrd device) const; void Resize(size_t new_size, T v = T()); diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index c9b50792d073..4e12ebc4cc85 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -8,12 +8,12 @@ #include -#include "../common/common.h" -#include "column_matrix.h" +#include "../data/adapter.h" // for SparsePageAdapterBatch +#include "../data/gradient_index.h" // for GHistIndexMatrix #include "quantile.h" #include "xgboost/base.h" -#include "xgboost/context.h" // Context -#include "xgboost/data.h" // SparsePage, SortedCSCPage +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for SparsePage, SortedCSCPage #if defined(XGBOOST_MM_PREFETCH_PRESENT) #include @@ -32,7 +32,7 @@ HistogramCuts::HistogramCuts() { } HistogramCuts SketchOnDMatrix(Context const *ctx, DMatrix *m, bst_bin_t max_bins, bool use_sorted, - Span const hessian) { + Span hessian) { HistogramCuts out; auto const &info = m->Info(); auto n_threads = ctx->Threads(); diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index eabdb86de278..2dfba72158bb 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -19,14 +19,13 @@ #include #include "categorical.h" +#include "cuda_context.cuh" // for CUDAContext #include "device_helpers.cuh" #include "hist_util.cuh" #include "hist_util.h" -#include "math.h" // NOLINT #include "quantile.h" #include "xgboost/host_device_vector.h" - namespace xgboost::common { constexpr float SketchContainer::kFactor; @@ -109,22 +108,19 @@ size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_row_t num_ro return std::min(sketch_batch_num_elements, kIntMax); } -void SortByWeight(dh::device_vector* weights, - dh::device_vector* sorted_entries) { +void SortByWeight(dh::device_vector* weights, dh::device_vector* sorted_entries) { // Sort both entries and wegihts. dh::XGBDeviceAllocator alloc; - thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries->begin(), - sorted_entries->end(), weights->begin(), - detail::EntryCompareOp()); + CHECK_EQ(weights->size(), sorted_entries->size()); + thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries->begin(), sorted_entries->end(), + weights->begin(), detail::EntryCompareOp()); // Scan weights dh::XGBCachingDeviceAllocator caching; - thrust::inclusive_scan_by_key(thrust::cuda::par(caching), - sorted_entries->begin(), sorted_entries->end(), - weights->begin(), weights->begin(), - [=] __device__(const Entry& a, const Entry& b) { - return a.index == b.index; - }); + thrust::inclusive_scan_by_key( + thrust::cuda::par(caching), sorted_entries->begin(), sorted_entries->end(), weights->begin(), + weights->begin(), + [=] __device__(const Entry& a, const Entry& b) { return a.index == b.index; }); } void RemoveDuplicatedCategories(int32_t device, MetaInfo const& info, Span d_cuts_ptr, @@ -200,159 +196,170 @@ void RemoveDuplicatedCategories(int32_t device, MetaInfo const& info, Span alloc; +void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo const& info, + std::size_t begin, std::size_t end, + SketchContainer* sketch_container, // <- output sketch + int num_cuts_per_feature, common::Span sample_weight) { dh::device_vector sorted_entries; if (page.data.DeviceCanRead()) { - const auto& device_data = page.data.ConstDevicePointer(); - sorted_entries = dh::device_vector(device_data + begin, device_data + end); + // direct copy if data is already on device + auto const& d_data = page.data.ConstDevicePointer(); + sorted_entries = dh::device_vector(d_data + begin, d_data + end); + } else { + const auto& h_data = page.data.ConstHostVector(); + sorted_entries = dh::device_vector(h_data.begin() + begin, h_data.begin() + end); + } + + bst_row_t base_rowid = page.base_rowid; + + dh::device_vector entry_weight; + auto cuctx = ctx->CUDACtx(); + if (!sample_weight.empty()) { + // Expand sample weight into entry weight. + CHECK_EQ(sample_weight.size(), info.num_row_); + entry_weight.resize(sorted_entries.size()); + auto d_temp_weight = dh::ToSpan(entry_weight); + page.offset.SetDevice(ctx->Device()); + auto row_ptrs = page.offset.ConstDeviceSpan(); + thrust::for_each_n(cuctx->CTP(), thrust::make_counting_iterator(0ul), entry_weight.size(), + [=] __device__(std::size_t idx) { + std::size_t element_idx = idx + begin; + std::size_t ridx = dh::SegmentId(row_ptrs, element_idx); + d_temp_weight[idx] = sample_weight[ridx + base_rowid]; + }); + detail::SortByWeight(&entry_weight, &sorted_entries); } else { - const auto& host_data = page.data.ConstHostVector(); - sorted_entries = dh::device_vector(host_data.begin() + begin, - host_data.begin() + end); + thrust::sort(cuctx->CTP(), sorted_entries.begin(), sorted_entries.end(), + detail::EntryCompareOp()); } - thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(), - sorted_entries.end(), detail::EntryCompareOp()); HostDeviceVector cuts_ptr; dh::caching_device_vector column_sizes_scan; data::IsValidFunctor dummy_is_valid(std::numeric_limits::quiet_NaN()); auto batch_it = dh::MakeTransformIterator( - sorted_entries.data().get(), - [] __device__(Entry const &e) -> data::COOTuple { - return {0, e.index, e.fvalue}; // row_idx is not needed for scanning column size. + sorted_entries.data().get(), [] __device__(Entry const& e) -> data::COOTuple { + return {0, e.index, e.fvalue}; // row_idx is not needed for scaning column size. }); - detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature, + detail::GetColumnSizesScan(ctx->Ordinal(), info.num_col_, num_cuts_per_feature, IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr, &column_sizes_scan); auto d_cuts_ptr = cuts_ptr.DeviceSpan(); - if (sketch_container->HasCategorical()) { - detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, nullptr, + auto p_weight = entry_weight.empty() ? nullptr : &entry_weight; + detail::RemoveDuplicatedCategories(ctx->Ordinal(), info, d_cuts_ptr, &sorted_entries, p_weight, &column_sizes_scan); } auto const& h_cuts_ptr = cuts_ptr.ConstHostVector(); CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size()); - // add cuts into sketches - sketch_container->Push(dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), - d_cuts_ptr, h_cuts_ptr.back()); + // Add cuts into sketches + sketch_container->Push(dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr, + h_cuts_ptr.back(), dh::ToSpan(entry_weight)); + sorted_entries.clear(); sorted_entries.shrink_to_fit(); CHECK_EQ(sorted_entries.capacity(), 0); CHECK_NE(cuts_ptr.Size(), 0); } -void ProcessWeightedBatch(int device, const SparsePage& page, - MetaInfo const& info, size_t begin, size_t end, - SketchContainer* sketch_container, int num_cuts_per_feature, - size_t num_columns, - bool is_ranking, Span d_group_ptr) { - auto weights = info.weights_.ConstDeviceSpan(); - - dh::XGBCachingDeviceAllocator alloc; - const auto& host_data = page.data.ConstHostVector(); - dh::device_vector sorted_entries(host_data.begin() + begin, - host_data.begin() + end); - - // Binary search to assign weights to each element - dh::device_vector temp_weights(sorted_entries.size()); - auto d_temp_weights = temp_weights.data().get(); - page.offset.SetDevice(device); - auto row_ptrs = page.offset.ConstDeviceSpan(); - size_t base_rowid = page.base_rowid; - if (is_ranking) { - CHECK_GE(d_group_ptr.size(), 2) - << "Must have at least 1 group for ranking."; - CHECK_EQ(weights.size(), d_group_ptr.size() - 1) - << "Weight size should equal to number of groups."; - dh::LaunchN(temp_weights.size(), [=] __device__(size_t idx) { - size_t element_idx = idx + begin; - size_t ridx = dh::SegmentId(row_ptrs, element_idx); - bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx + base_rowid); - d_temp_weights[idx] = weights[group_idx]; - }); - } else { - dh::LaunchN(temp_weights.size(), [=] __device__(size_t idx) { - size_t element_idx = idx + begin; - size_t ridx = dh::SegmentId(row_ptrs, element_idx); - d_temp_weights[idx] = weights[ridx + base_rowid]; - }); +// Unify group weight, Hessian, and sample weight into sample weight. +[[nodiscard]] Span UnifyWeight(CUDAContext const* cuctx, MetaInfo const& info, + common::Span hessian, + HostDeviceVector* p_out_weight) { + if (hessian.empty()) { + if (info.IsRanking() && !info.weights_.Empty()) { + common::Span group_weight = info.weights_.ConstDeviceSpan(); + dh::device_vector group_ptr(info.group_ptr_); + auto d_group_ptr = dh::ToSpan(group_ptr); + CHECK_GE(d_group_ptr.size(), 2) << "Must have at least 1 group for ranking."; + auto d_weight = info.weights_.ConstDeviceSpan(); + CHECK_EQ(d_weight.size(), d_group_ptr.size() - 1) + << "Weight size should equal to number of groups."; + p_out_weight->Resize(info.num_row_); + auto d_weight_out = p_out_weight->DeviceSpan(); + + thrust::for_each_n(cuctx->CTP(), thrust::make_counting_iterator(0ul), d_weight_out.size(), + [=] XGBOOST_DEVICE(std::size_t i) { + auto gidx = dh::SegmentId(d_group_ptr, i); + d_weight_out[i] = d_weight[gidx]; + }); + return p_out_weight->ConstDeviceSpan(); + } else { + return info.weights_.ConstDeviceSpan(); + } } - detail::SortByWeight(&temp_weights, &sorted_entries); - HostDeviceVector cuts_ptr; - dh::caching_device_vector column_sizes_scan; - data::IsValidFunctor dummy_is_valid(std::numeric_limits::quiet_NaN()); - auto batch_it = dh::MakeTransformIterator( - sorted_entries.data().get(), - [] __device__(Entry const &e) -> data::COOTuple { - return {0, e.index, e.fvalue}; // row_idx is not needed for scaning column size. - }); - detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature, - IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr, - &column_sizes_scan); - auto d_cuts_ptr = cuts_ptr.DeviceSpan(); - if (sketch_container->HasCategorical()) { - detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, &temp_weights, - &column_sizes_scan); + // sketch with hessian as weight + p_out_weight->Resize(info.num_row_); + auto d_weight_out = p_out_weight->DeviceSpan(); + if (!info.weights_.Empty()) { + // merge sample weight with hessian + auto d_weight = info.weights_.ConstDeviceSpan(); + if (info.IsRanking()) { + dh::device_vector group_ptr(info.group_ptr_); + CHECK_EQ(hessian.size(), d_weight_out.size()); + auto d_group_ptr = dh::ToSpan(group_ptr); + CHECK_GE(d_group_ptr.size(), 2) << "Must have at least 1 group for ranking."; + CHECK_EQ(d_weight.size(), d_group_ptr.size() - 1) + << "Weight size should equal to number of groups."; + thrust::for_each_n(cuctx->CTP(), thrust::make_counting_iterator(0ul), hessian.size(), + [=] XGBOOST_DEVICE(std::size_t i) { + d_weight_out[i] = d_weight[dh::SegmentId(d_group_ptr, i)] * hessian(i); + }); + } else { + CHECK_EQ(hessian.size(), info.num_row_); + CHECK_EQ(hessian.size(), d_weight.size()); + CHECK_EQ(hessian.size(), d_weight_out.size()); + thrust::for_each_n( + cuctx->CTP(), thrust::make_counting_iterator(0ul), hessian.size(), + [=] XGBOOST_DEVICE(std::size_t i) { d_weight_out[i] = d_weight[i] * hessian(i); }); + } + } else { + // copy hessian as weight + CHECK_EQ(d_weight_out.size(), hessian.size()); + dh::safe_cuda(cudaMemcpyAsync(d_weight_out.data(), hessian.data(), hessian.size_bytes(), + cudaMemcpyDefault)); } + return d_weight_out; +} - auto const& h_cuts_ptr = cuts_ptr.ConstHostVector(); +HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin, + Span hessian, + std::size_t sketch_batch_num_elements) { + auto const& info = p_fmat->Info(); + bool has_weight = !info.weights_.Empty(); + info.feature_types.SetDevice(ctx->Device()); - // Extract cuts - sketch_container->Push(dh::ToSpan(sorted_entries), - dh::ToSpan(column_sizes_scan), d_cuts_ptr, - h_cuts_ptr.back(), dh::ToSpan(temp_weights)); - sorted_entries.clear(); - sorted_entries.shrink_to_fit(); -} + HostDeviceVector weight; + weight.SetDevice(ctx->Device()); -HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, - size_t sketch_batch_num_elements) { - dmat->Info().feature_types.SetDevice(device); - dmat->Info().feature_types.ConstDevicePointer(); // pull to device early // Configure batch size based on available memory - bool has_weights = dmat->Info().weights_.Size() > 0; - size_t num_cuts_per_feature = - detail::RequiredSampleCutsPerColumn(max_bins, dmat->Info().num_row_); + std::size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(max_bin, info.num_row_); sketch_batch_num_elements = detail::SketchBatchNumElements( - sketch_batch_num_elements, - dmat->Info().num_row_, - dmat->Info().num_col_, - dmat->Info().num_nonzero_, - device, num_cuts_per_feature, has_weights); + sketch_batch_num_elements, info.num_row_, info.num_col_, info.num_nonzero_, ctx->Ordinal(), + num_cuts_per_feature, has_weight); + + CUDAContext const* cuctx = ctx->CUDACtx(); + + info.weights_.SetDevice(ctx->Device()); + auto d_weight = UnifyWeight(cuctx, info, hessian, &weight); HistogramCuts cuts; - SketchContainer sketch_container(dmat->Info().feature_types, max_bins, dmat->Info().num_col_, - dmat->Info().num_row_, device); - - dmat->Info().weights_.SetDevice(device); - for (const auto& batch : dmat->GetBatches()) { - size_t batch_nnz = batch.data.Size(); - auto const& info = dmat->Info(); - for (auto begin = 0ull; begin < batch_nnz; begin += sketch_batch_num_elements) { - size_t end = std::min(batch_nnz, static_cast(begin + sketch_batch_num_elements)); - if (has_weights) { - bool is_ranking = HostSketchContainer::UseGroup(dmat->Info()); - dh::caching_device_vector groups(info.group_ptr_.cbegin(), - info.group_ptr_.cend()); - ProcessWeightedBatch( - device, batch, dmat->Info(), begin, end, - &sketch_container, - num_cuts_per_feature, - dmat->Info().num_col_, - is_ranking, dh::ToSpan(groups)); - } else { - ProcessBatch(device, dmat->Info(), batch, begin, end, &sketch_container, - num_cuts_per_feature, dmat->Info().num_col_); - } + SketchContainer sketch_container(info.feature_types, max_bin, info.num_col_, info.num_row_, + ctx->Ordinal()); + CHECK_EQ(has_weight || !hessian.empty(), !d_weight.empty()); + for (const auto& page : p_fmat->GetBatches()) { + std::size_t page_nnz = page.data.Size(); + for (auto begin = 0ull; begin < page_nnz; begin += sketch_batch_num_elements) { + std::size_t end = + std::min(page_nnz, static_cast(begin + sketch_batch_num_elements)); + ProcessWeightedBatch(ctx, page, info, begin, end, &sketch_container, num_cuts_per_feature, + d_weight); } } - sketch_container.MakeCuts(&cuts, dmat->Info().IsColumnSplit()); + + sketch_container.MakeCuts(&cuts, p_fmat->Info().IsColumnSplit()); return cuts; } } // namespace xgboost::common diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index 0dcdad64dbc1..f13f01b3e9ed 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -11,14 +11,13 @@ #include // for size_t -#include "../data/device_adapter.cuh" +#include "../data/adapter.h" // for IsValidFunctor #include "device_helpers.cuh" #include "hist_util.h" #include "quantile.cuh" -#include "timer.h" +#include "xgboost/span.h" // for IterSpan -namespace xgboost { -namespace common { +namespace xgboost::common { namespace cuda { /** * copy and paste of the host version, we can't make it a __host__ __device__ function as @@ -246,10 +245,35 @@ void RemoveDuplicatedCategories(int32_t device, MetaInfo const& info, Span* p_column_sizes_scan); } // namespace detail -// Compute sketch on DMatrix. -// sketch_batch_num_elements 0 means autodetect. Only modify this for testing. -HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, - size_t sketch_batch_num_elements = 0); +/** + * @brief Compute sketch on DMatrix with GPU and Hessian as weight. + * + * @param ctx Runtime context + * @param p_fmat Training feature matrix + * @param max_bin Maximum number of bins for each feature + * @param hessian Hessian vector. + * @param sketch_batch_num_elements 0 means autodetect. Only modify this for testing. + * + * @return Quantile cuts + */ +HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin, + Span hessian, + std::size_t sketch_batch_num_elements = 0); + +/** + * @brief Compute sketch on DMatrix with GPU. + * + * @param ctx Runtime context + * @param p_fmat Training feature matrix + * @param max_bin Maximum number of bins for each feature + * @param sketch_batch_num_elements 0 means autodetect. Only modify this for testing. + * + * @return Quantile cuts + */ +inline HistogramCuts DeviceSketch(Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin, + std::size_t sketch_batch_num_elements = 0) { + return DeviceSketchWithHessian(ctx, p_fmat, max_bin, {}, sketch_batch_num_elements); +} template void ProcessSlidingWindow(AdapterBatch const &batch, MetaInfo const &info, @@ -417,7 +441,5 @@ void AdapterDeviceSketch(Batch batch, int num_bins, } } } -} // namespace common -} // namespace xgboost - +} // namespace xgboost::common #endif // COMMON_HIST_UTIL_CUH_ diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 2781da8e0cff..18f208467a56 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -174,7 +174,7 @@ class HistogramCuts { * but consumes more memory. */ HistogramCuts SketchOnDMatrix(Context const* ctx, DMatrix* m, bst_bin_t max_bins, - bool use_sorted = false, Span const hessian = {}); + bool use_sorted = false, Span hessian = {}); enum BinTypeSize : uint8_t { kUint8BinsTypeSize = 1, diff --git a/src/common/host_device_vector.cc b/src/common/host_device_vector.cc index 55c0ecf202f8..175a5cbf1b10 100644 --- a/src/common/host_device_vector.cc +++ b/src/common/host_device_vector.cc @@ -168,6 +168,9 @@ bool HostDeviceVector::DeviceCanWrite() const { template void HostDeviceVector::SetDevice(int) const {} +template +void HostDeviceVector::SetDevice(DeviceOrd) const {} + // explicit instantiations are required, as HostDeviceVector isn't header-only template class HostDeviceVector; template class HostDeviceVector; diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu index 1fa9a3b2200c..7acb6719ba91 100644 --- a/src/common/host_device_vector.cu +++ b/src/common/host_device_vector.cu @@ -394,6 +394,11 @@ void HostDeviceVector::SetDevice(int device) const { impl_->SetDevice(device); } +template +void HostDeviceVector::SetDevice(DeviceOrd device) const { + impl_->SetDevice(device.ordinal); +} + template void HostDeviceVector::Resize(size_t new_size, T v) { impl_->Resize(new_size, v); diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 0ccd7a08138b..7097df405f54 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -131,7 +131,7 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchP monitor_.Start("Quantiles"); // Create the quantile sketches for the dmatrix and initialize HistogramCuts. row_stride = GetRowStride(dmat); - cuts_ = common::DeviceSketch(ctx->gpu_id, dmat, param.max_bin); + cuts_ = common::DeviceSketch(ctx, dmat, param.max_bin); monitor_.Stop("Quantiles"); monitor_.Start("InitCompressedData"); diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc index 1d47ae9e6c63..1ee1bd60ba09 100644 --- a/src/data/gradient_index.cc +++ b/src/data/gradient_index.cc @@ -21,7 +21,7 @@ GHistIndexMatrix::GHistIndexMatrix() : columns_{std::make_unique hess) + common::Span hess) : max_numeric_bins_per_feat{max_bins_per_feat} { CHECK(p_fmat->SingleColBlock()); // We use sorted sketching for approx tree method since it's more efficient in diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h index 901451ad908e..0bb93fc20900 100644 --- a/src/data/gradient_index.h +++ b/src/data/gradient_index.h @@ -160,7 +160,7 @@ class GHistIndexMatrix { * \brief Constrcutor for SimpleDMatrix. */ GHistIndexMatrix(Context const* ctx, DMatrix* x, bst_bin_t max_bins_per_feat, - double sparse_thresh, bool sorted_sketch, common::Span hess = {}); + double sparse_thresh, bool sorted_sketch, common::Span hess = {}); /** * \brief Constructor for Iterative DMatrix. Initialize basic information and prepare * for push batch. diff --git a/src/data/sparse_page_dmatrix.cu b/src/data/sparse_page_dmatrix.cu index 38304f72509e..1d9af9f06d25 100644 --- a/src/data/sparse_page_dmatrix.cu +++ b/src/data/sparse_page_dmatrix.cu @@ -25,8 +25,8 @@ BatchSet SparsePageDMatrix::GetEllpackBatches(Context const* ctx, cache_info_.erase(id); MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_); std::unique_ptr cuts; - cuts = std::make_unique( - common::DeviceSketch(ctx->gpu_id, this, param.max_bin, 0)); + cuts = + std::make_unique(common::DeviceSketch(ctx, this, param.max_bin, 0)); this->InitializeSparsePage(ctx); // reset after use. row_stride = GetRowStride(this); diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 2d5735925565..304e8567e218 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -3,17 +3,22 @@ */ #include #include +#include // for bst_bin_t #include #include -#include -#include +#include // for transform +#include // for floor +#include // for size_t +#include // for numeric_limits +#include // for string, to_string +#include // for tuple, make_tuple +#include // for vector #include "../../../include/xgboost/logging.h" #include "../../../src/common/device_helpers.cuh" #include "../../../src/common/hist_util.cuh" #include "../../../src/common/hist_util.h" -#include "../../../src/common/math.h" #include "../../../src/data/device_adapter.cuh" #include "../../../src/data/simple_dmatrix.h" #include "../data/test_array_interface.h" @@ -21,8 +26,7 @@ #include "../helpers.h" #include "test_hist_util.h" -namespace xgboost { -namespace common { +namespace xgboost::common { template HistogramCuts GetHostCuts(Context const* ctx, AdapterT* adapter, int num_bins, float missing) { @@ -32,16 +36,17 @@ HistogramCuts GetHostCuts(Context const* ctx, AdapterT* adapter, int num_bins, f } TEST(HistUtil, DeviceSketch) { + auto ctx = MakeCUDACtx(0); int num_columns = 1; int num_bins = 4; std::vector x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 7.0f, -1.0f}; int num_rows = x.size(); auto dmat = GetDMatrixFromData(x, num_rows, num_columns); - auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); + auto device_cuts = DeviceSketch(&ctx, dmat.get(), num_bins); - Context ctx; - HistogramCuts host_cuts = SketchOnDMatrix(&ctx, dmat.get(), num_bins); + Context cpu_ctx; + HistogramCuts host_cuts = SketchOnDMatrix(&cpu_ctx, dmat.get(), num_bins); EXPECT_EQ(device_cuts.Values(), host_cuts.Values()); EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs()); @@ -65,6 +70,7 @@ TEST(HistUtil, SketchBatchNumElements) { } TEST(HistUtil, DeviceSketchMemory) { + auto ctx = MakeCUDACtx(0); int num_columns = 100; int num_rows = 1000; int num_bins = 256; @@ -73,7 +79,7 @@ TEST(HistUtil, DeviceSketchMemory) { dh::GlobalMemoryLogger().Clear(); ConsoleLogger::Configure({{"verbosity", "3"}}); - auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); + auto device_cuts = DeviceSketch(&ctx, dmat.get(), num_bins); size_t bytes_required = detail::RequiredMemory( num_rows, num_columns, num_rows * num_columns, num_bins, false); @@ -83,6 +89,7 @@ TEST(HistUtil, DeviceSketchMemory) { } TEST(HistUtil, DeviceSketchWeightsMemory) { + auto ctx = MakeCUDACtx(0); int num_columns = 100; int num_rows = 1000; int num_bins = 256; @@ -92,7 +99,7 @@ TEST(HistUtil, DeviceSketchWeightsMemory) { dh::GlobalMemoryLogger().Clear(); ConsoleLogger::Configure({{"verbosity", "3"}}); - auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); + auto device_cuts = DeviceSketch(&ctx, dmat.get(), num_bins); ConsoleLogger::Configure({{"verbosity", "0"}}); size_t bytes_required = detail::RequiredMemory( @@ -102,43 +109,44 @@ TEST(HistUtil, DeviceSketchWeightsMemory) { } TEST(HistUtil, DeviceSketchDeterminism) { + auto ctx = MakeCUDACtx(0); int num_rows = 500; int num_columns = 5; int num_bins = 256; auto x = GenerateRandom(num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns); - auto reference_sketch = DeviceSketch(0, dmat.get(), num_bins); + auto reference_sketch = DeviceSketch(&ctx, dmat.get(), num_bins); size_t constexpr kRounds{ 100 }; for (size_t r = 0; r < kRounds; ++r) { - auto new_sketch = DeviceSketch(0, dmat.get(), num_bins); + auto new_sketch = DeviceSketch(&ctx, dmat.get(), num_bins); ASSERT_EQ(reference_sketch.Values(), new_sketch.Values()); ASSERT_EQ(reference_sketch.MinValues(), new_sketch.MinValues()); } } TEST(HistUtil, DeviceSketchCategoricalAsNumeric) { - int categorical_sizes[] = {2, 6, 8, 12}; + auto ctx = MakeCUDACtx(0); + auto categorical_sizes = {2, 6, 8, 12}; int num_bins = 256; - int sizes[] = {25, 100, 1000}; + auto sizes = {25, 100, 1000}; for (auto n : sizes) { for (auto num_categories : categorical_sizes) { auto x = GenerateRandomCategoricalSingleColumn(n, num_categories); auto dmat = GetDMatrixFromData(x, n, 1); - auto cuts = DeviceSketch(0, dmat.get(), num_bins); + auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins); ValidateCuts(cuts, dmat.get(), num_bins); } } } TEST(HistUtil, DeviceSketchCategoricalFeatures) { - TestCategoricalSketch(1000, 256, 32, false, - [](DMatrix *p_fmat, int32_t num_bins) { - return DeviceSketch(0, p_fmat, num_bins); - }); - TestCategoricalSketch(1000, 256, 32, true, - [](DMatrix *p_fmat, int32_t num_bins) { - return DeviceSketch(0, p_fmat, num_bins); - }); + auto ctx = MakeCUDACtx(0); + TestCategoricalSketch(1000, 256, 32, false, [ctx](DMatrix* p_fmat, int32_t num_bins) { + return DeviceSketch(&ctx, p_fmat, num_bins); + }); + TestCategoricalSketch(1000, 256, 32, true, [ctx](DMatrix* p_fmat, int32_t num_bins) { + return DeviceSketch(&ctx, p_fmat, num_bins); + }); } void TestMixedSketch() { @@ -162,7 +170,8 @@ void TestMixedSketch() { m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical); m->Info().feature_types.HostVector().push_back(FeatureType::kNumerical); - auto cuts = DeviceSketch(0, m.get(), n_bins); + auto ctx = MakeCUDACtx(0); + auto cuts = DeviceSketch(&ctx, m.get(), n_bins); ASSERT_EQ(cuts.Values().size(), n_bins + n_categories); } @@ -234,37 +243,40 @@ TEST(HistUtil, RemoveDuplicatedCategories) { } TEST(HistUtil, DeviceSketchMultipleColumns) { - int bin_sizes[] = {2, 16, 256, 512}; - int sizes[] = {100, 1000, 1500}; + auto ctx = MakeCUDACtx(0); + auto bin_sizes = {2, 16, 256, 512}; + auto sizes = {100, 1000, 1500}; int num_columns = 5; for (auto num_rows : sizes) { auto x = GenerateRandom(num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns); for (auto num_bins : bin_sizes) { - auto cuts = DeviceSketch(0, dmat.get(), num_bins); + auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins); ValidateCuts(cuts, dmat.get(), num_bins); } } } TEST(HistUtil, DeviceSketchMultipleColumnsWeights) { - int bin_sizes[] = {2, 16, 256, 512}; - int sizes[] = {100, 1000, 1500}; + auto ctx = MakeCUDACtx(0); + auto bin_sizes = {2, 16, 256, 512}; + auto sizes = {100, 1000, 1500}; int num_columns = 5; for (auto num_rows : sizes) { auto x = GenerateRandom(num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns); dmat->Info().weights_.HostVector() = GenerateRandomWeights(num_rows); for (auto num_bins : bin_sizes) { - auto cuts = DeviceSketch(0, dmat.get(), num_bins); + auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins); ValidateCuts(cuts, dmat.get(), num_bins); } } } TEST(HistUitl, DeviceSketchWeights) { - int bin_sizes[] = {2, 16, 256, 512}; - int sizes[] = {100, 1000, 1500}; + auto ctx = MakeCUDACtx(0); + auto bin_sizes = {2, 16, 256, 512}; + auto sizes = {100, 1000, 1500}; int num_columns = 5; for (auto num_rows : sizes) { auto x = GenerateRandom(num_rows, num_columns); @@ -274,8 +286,8 @@ TEST(HistUitl, DeviceSketchWeights) { h_weights.resize(num_rows); std::fill(h_weights.begin(), h_weights.end(), 1.0f); for (auto num_bins : bin_sizes) { - auto cuts = DeviceSketch(0, dmat.get(), num_bins); - auto wcuts = DeviceSketch(0, weighted_dmat.get(), num_bins); + auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins); + auto wcuts = DeviceSketch(&ctx, weighted_dmat.get(), num_bins); ASSERT_EQ(cuts.MinValues(), wcuts.MinValues()); ASSERT_EQ(cuts.Ptrs(), wcuts.Ptrs()); ASSERT_EQ(cuts.Values(), wcuts.Values()); @@ -286,14 +298,15 @@ TEST(HistUitl, DeviceSketchWeights) { } TEST(HistUtil, DeviceSketchBatches) { + auto ctx = MakeCUDACtx(0); int num_bins = 256; int num_rows = 5000; - int batch_sizes[] = {0, 100, 1500, 6000}; + auto batch_sizes = {0, 100, 1500, 6000}; int num_columns = 5; for (auto batch_size : batch_sizes) { auto x = GenerateRandom(num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns); - auto cuts = DeviceSketch(0, dmat.get(), num_bins, batch_size); + auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins, batch_size); ValidateCuts(cuts, dmat.get(), num_bins); } @@ -301,8 +314,8 @@ TEST(HistUtil, DeviceSketchBatches) { size_t batches = 16; auto x = GenerateRandom(num_rows * batches, num_columns); auto dmat = GetDMatrixFromData(x, num_rows * batches, num_columns); - auto cuts_with_batches = DeviceSketch(0, dmat.get(), num_bins, num_rows); - auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0); + auto cuts_with_batches = DeviceSketch(&ctx, dmat.get(), num_bins, num_rows); + auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins, 0); auto const& cut_values_batched = cuts_with_batches.Values(); auto const& cut_values = cuts.Values(); @@ -313,15 +326,16 @@ TEST(HistUtil, DeviceSketchBatches) { } TEST(HistUtil, DeviceSketchMultipleColumnsExternal) { - int bin_sizes[] = {2, 16, 256, 512}; - int sizes[] = {100, 1000, 1500}; + auto ctx = MakeCUDACtx(0); + auto bin_sizes = {2, 16, 256, 512}; + auto sizes = {100, 1000, 1500}; int num_columns =5; for (auto num_rows : sizes) { auto x = GenerateRandom(num_rows, num_columns); dmlc::TemporaryDirectory temp; auto dmat = GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, temp); for (auto num_bins : bin_sizes) { - auto cuts = DeviceSketch(0, dmat.get(), num_bins); + auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins); ValidateCuts(cuts, dmat.get(), num_bins); } } @@ -329,8 +343,9 @@ TEST(HistUtil, DeviceSketchMultipleColumnsExternal) { // See https://github.com/dmlc/xgboost/issues/5866. TEST(HistUtil, DeviceSketchExternalMemoryWithWeights) { - int bin_sizes[] = {2, 16, 256, 512}; - int sizes[] = {100, 1000, 1500}; + auto ctx = MakeCUDACtx(0); + auto bin_sizes = {2, 16, 256, 512}; + auto sizes = {100, 1000, 1500}; int num_columns = 5; dmlc::TemporaryDirectory temp; for (auto num_rows : sizes) { @@ -338,7 +353,7 @@ TEST(HistUtil, DeviceSketchExternalMemoryWithWeights) { auto dmat = GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, temp); dmat->Info().weights_.HostVector() = GenerateRandomWeights(num_rows); for (auto num_bins : bin_sizes) { - auto cuts = DeviceSketch(0, dmat.get(), num_bins); + auto cuts = DeviceSketch(&ctx, dmat.get(), num_bins); ValidateCuts(cuts, dmat.get(), num_bins); } } @@ -504,9 +519,9 @@ void TestCategoricalSketchAdapter(size_t n, size_t num_categories, } TEST(HistUtil, AdapterDeviceSketchCategorical) { - int categorical_sizes[] = {2, 6, 8, 12}; + auto categorical_sizes = {2, 6, 8, 12}; int num_bins = 256; - int sizes[] = {25, 100, 1000}; + auto sizes = {25, 100, 1000}; for (auto n : sizes) { for (auto num_categories : categorical_sizes) { auto x = GenerateRandomCategoricalSingleColumn(n, num_categories); @@ -521,8 +536,8 @@ TEST(HistUtil, AdapterDeviceSketchCategorical) { } TEST(HistUtil, AdapterDeviceSketchMultipleColumns) { - int bin_sizes[] = {2, 16, 256, 512}; - int sizes[] = {100, 1000, 1500}; + auto bin_sizes = {2, 16, 256, 512}; + auto sizes = {100, 1000, 1500}; int num_columns = 5; for (auto num_rows : sizes) { auto x = GenerateRandom(num_rows, num_columns); @@ -538,7 +553,7 @@ TEST(HistUtil, AdapterDeviceSketchMultipleColumns) { TEST(HistUtil, AdapterDeviceSketchBatches) { int num_bins = 256; int num_rows = 5000; - int batch_sizes[] = {0, 100, 1500, 6000}; + auto batch_sizes = {0, 100, 1500, 6000}; int num_columns = 5; for (auto batch_size : batch_sizes) { auto x = GenerateRandom(num_rows, num_columns); @@ -619,14 +634,15 @@ TEST(HistUtil, GetColumnSize) { // Check sketching from adapter or DMatrix results in the same answer // Consistency here is useful for testing and user experience TEST(HistUtil, SketchingEquivalent) { - int bin_sizes[] = {2, 16, 256, 512}; - int sizes[] = {100, 1000, 1500}; + auto ctx = MakeCUDACtx(0); + auto bin_sizes = {2, 16, 256, 512}; + auto sizes = {100, 1000, 1500}; int num_columns = 5; for (auto num_rows : sizes) { auto x = GenerateRandom(num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns); for (auto num_bins : bin_sizes) { - auto dmat_cuts = DeviceSketch(0, dmat.get(), num_bins); + auto dmat_cuts = DeviceSketch(&ctx, dmat.get(), num_bins); auto x_device = thrust::device_vector(x); auto adapter = AdapterFromData(x_device, num_rows, num_columns); common::HistogramCuts adapter_cuts = MakeUnweightedCutsForTest( @@ -641,21 +657,25 @@ TEST(HistUtil, SketchingEquivalent) { } TEST(HistUtil, DeviceSketchFromGroupWeights) { + auto ctx = MakeCUDACtx(0); size_t constexpr kRows = 3000, kCols = 200, kBins = 256; size_t constexpr kGroups = 10; auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(); + + // sketch with group weight auto& h_weights = m->Info().weights_.HostVector(); - h_weights.resize(kRows); + h_weights.resize(kGroups); std::fill(h_weights.begin(), h_weights.end(), 1.0f); std::vector groups(kGroups); for (size_t i = 0; i < kGroups; ++i) { groups[i] = kRows / kGroups; } m->SetInfo("group", groups.data(), DataType::kUInt32, kGroups); - HistogramCuts weighted_cuts = DeviceSketch(0, m.get(), kBins, 0); + HistogramCuts weighted_cuts = DeviceSketch(&ctx, m.get(), kBins, 0); + // sketch with no weight h_weights.clear(); - HistogramCuts cuts = DeviceSketch(0, m.get(), kBins, 0); + HistogramCuts cuts = DeviceSketch(&ctx, m.get(), kBins, 0); ASSERT_EQ(cuts.Values().size(), weighted_cuts.Values().size()); ASSERT_EQ(cuts.MinValues().size(), weighted_cuts.MinValues().size()); @@ -723,9 +743,10 @@ void TestAdapterSketchFromWeights(bool with_group) { ASSERT_EQ(cuts.Ptrs().size(), kCols + 1); ValidateCuts(cuts, dmat.get(), kBins); + auto cuda_ctx = MakeCUDACtx(0); if (with_group) { dmat->Info().weights_ = decltype(dmat->Info().weights_)(); // remove weight - HistogramCuts non_weighted = DeviceSketch(0, dmat.get(), kBins, 0); + HistogramCuts non_weighted = DeviceSketch(&cuda_ctx, dmat.get(), kBins, 0); for (size_t i = 0; i < cuts.Values().size(); ++i) { ASSERT_EQ(cuts.Values()[i], non_weighted.Values()[i]); } @@ -760,5 +781,150 @@ TEST(HistUtil, AdapterSketchFromWeights) { TestAdapterSketchFromWeights(false); TestAdapterSketchFromWeights(true); } -} // namespace common -} // namespace xgboost + +namespace { +class DeviceSketchWithHessianTest + : public ::testing::TestWithParam> { + bst_feature_t n_features_ = 5; + bst_group_t n_groups_{3}; + + auto GenerateHessian(Context const* ctx, bst_row_t n_samples) const { + HostDeviceVector hessian; + auto& h_hess = hessian.HostVector(); + h_hess = GenerateRandomWeights(n_samples); + std::mt19937 rng(0); + std::shuffle(h_hess.begin(), h_hess.end(), rng); + hessian.SetDevice(ctx->Device()); + return hessian; + } + + void Check(Context const* ctx, std::shared_ptr p_fmat, bst_bin_t n_bins, + HostDeviceVector const& hessian, std::vector const& w) const { + auto const& h_hess = hessian.ConstHostVector(); + { + auto& h_weight = p_fmat->Info().weights_.HostVector(); + h_weight = w; + } + + HistogramCuts cuts_hess = + DeviceSketchWithHessian(ctx, p_fmat.get(), n_bins, hessian.ConstDeviceSpan()); + ValidateCuts(cuts_hess, p_fmat.get(), n_bins); + + // merge hessian + { + auto& h_weight = p_fmat->Info().weights_.HostVector(); + ASSERT_EQ(h_weight.size(), h_hess.size()); + for (std::size_t i = 0; i < h_weight.size(); ++i) { + h_weight[i] = w[i] * h_hess[i]; + } + } + + HistogramCuts cuts_wh = DeviceSketch(ctx, p_fmat.get(), n_bins); + ValidateCuts(cuts_wh, p_fmat.get(), n_bins); + ASSERT_EQ(cuts_hess.Values().size(), cuts_wh.Values().size()); + for (std::size_t i = 0; i < cuts_hess.Values().size(); ++i) { + ASSERT_NEAR(cuts_wh.Values()[i], cuts_hess.Values()[i], kRtEps); + } + + p_fmat->Info().weights_.HostVector() = w; + } + + protected: + Context ctx_ = MakeCUDACtx(0); + + void TestLTR(Context const* ctx, bst_row_t n_samples, bst_bin_t n_bins) const { + auto x = GenerateRandom(n_samples, n_features_); + + std::vector gptr; + gptr.resize(n_groups_ + 1, 0); + gptr[1] = n_samples / n_groups_; + gptr[2] = n_samples / n_groups_ + gptr[1]; + gptr.back() = n_samples; + + auto hessian = this->GenerateHessian(ctx, n_samples); + auto const& h_hess = hessian.ConstHostVector(); + auto p_fmat = GetDMatrixFromData(x, n_samples, n_features_); + p_fmat->Info().group_ptr_ = gptr; + + // test with constant group weight + std::vector w(n_groups_, 1.0f); + p_fmat->Info().weights_.HostVector() = w; + HistogramCuts cuts_hess = + DeviceSketchWithHessian(ctx, p_fmat.get(), n_bins, hessian.ConstDeviceSpan()); + // make validation easier by converting it into sample weight. + p_fmat->Info().weights_.HostVector() = h_hess; + p_fmat->Info().group_ptr_.clear(); + ValidateCuts(cuts_hess, p_fmat.get(), n_bins); + // restore ltr properties + p_fmat->Info().weights_.HostVector() = w; + p_fmat->Info().group_ptr_ = gptr; + + // test with random group weight + w = GenerateRandomWeights(n_groups_); + p_fmat->Info().weights_.HostVector() = w; + cuts_hess = DeviceSketchWithHessian(ctx, p_fmat.get(), n_bins, hessian.ConstDeviceSpan()); + // make validation easier by converting it into sample weight. + p_fmat->Info().weights_.HostVector() = h_hess; + p_fmat->Info().group_ptr_.clear(); + ValidateCuts(cuts_hess, p_fmat.get(), n_bins); + + // merge hessian with sample weight + p_fmat->Info().weights_.Resize(n_samples); + p_fmat->Info().group_ptr_.clear(); + for (std::size_t i = 0; i < h_hess.size(); ++i) { + auto gidx = dh::SegmentId(Span{gptr.data(), gptr.size()}, i); + p_fmat->Info().weights_.HostVector()[i] = w[gidx] * h_hess[i]; + } + auto cuts = DeviceSketch(ctx, p_fmat.get(), n_bins); + ValidateCuts(cuts, p_fmat.get(), n_bins); + ASSERT_EQ(cuts.Values().size(), cuts_hess.Values().size()); + for (std::size_t i = 0; i < cuts.Values().size(); ++i) { + EXPECT_NEAR(cuts.Values()[i], cuts_hess.Values()[i], 1e-4f); + } + } + + void TestRegression(Context const* ctx, bst_row_t n_samples, bst_bin_t n_bins) const { + auto x = GenerateRandom(n_samples, n_features_); + auto p_fmat = GetDMatrixFromData(x, n_samples, n_features_); + std::vector w = GenerateRandomWeights(n_samples); + + auto hessian = this->GenerateHessian(ctx, n_samples); + + this->Check(ctx, p_fmat, n_bins, hessian, w); + } +}; + +auto MakeParamsForTest() { + std::vector sizes = {1, 2, 256, 512, 1000, 1500}; + std::vector bin_sizes = {2, 16, 256, 512}; + std::vector> configs; + for (auto n_samples : sizes) { + for (auto n_bins : bin_sizes) { + configs.emplace_back(true, n_samples, n_bins); + configs.emplace_back(false, n_samples, n_bins); + } + } + return configs; +} +} // namespace + +TEST_P(DeviceSketchWithHessianTest, DeviceSketchWithHessian) { + auto param = GetParam(); + auto n_samples = std::get<1>(param); + auto n_bins = std::get<2>(param); + if (std::get<0>(param)) { + this->TestLTR(&ctx_, n_samples, n_bins); + } else { + this->TestRegression(&ctx_, n_samples, n_bins); + } +} + +INSTANTIATE_TEST_SUITE_P( + HistUtil, DeviceSketchWithHessianTest, ::testing::ValuesIn(MakeParamsForTest()), + [](::testing::TestParamInfo const& info) { + auto task = std::get<0>(info.param) ? "ltr" : "reg"; + auto n_samples = std::to_string(std::get<1>(info.param)); + auto n_bins = std::to_string(std::get<2>(info.param)); + return std::string{task} + "_" + n_samples + "_" + n_bins; + }); +} // namespace xgboost::common diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index d2dc802a93e4..eda55ee479da 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -1,9 +1,14 @@ +/** + * Copyright 2020-2023, XGBoost contributors + */ #include -#include "test_quantile.h" -#include "../helpers.h" + #include "../../../src/collective/communicator-inl.cuh" #include "../../../src/common/hist_util.cuh" #include "../../../src/common/quantile.cuh" +#include "../../../src/data/device_adapter.cuh" // CupyAdapter +#include "../helpers.h" +#include "test_quantile.h" namespace xgboost { namespace { @@ -437,13 +442,13 @@ void TestColumnSplitBasic() { }()}; // Generate cuts for distributed environment. - auto const device = rank; - HistogramCuts distributed_cuts = common::DeviceSketch(device, m.get(), kBins); + auto ctx = MakeCUDACtx(rank); + HistogramCuts distributed_cuts = common::DeviceSketch(&ctx, m.get(), kBins); // Generate cuts for single node environment collective::Finalize(); CHECK_EQ(collective::GetWorldSize(), 1); - HistogramCuts single_node_cuts = common::DeviceSketch(device, m.get(), kBins); + HistogramCuts single_node_cuts = common::DeviceSketch(&ctx, m.get(), kBins); auto const& sptrs = single_node_cuts.Ptrs(); auto const& dptrs = distributed_cuts.Ptrs(); From ba616fa12dfb1559b2fdb91bf20954c093a9b8ae Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 19 Jul 2023 17:53:02 +0800 Subject: [PATCH 2/2] More tests. --- tests/cpp/common/test_hist_util.cu | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 304e8567e218..91baad981f64 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -798,8 +798,9 @@ class DeviceSketchWithHessianTest return hessian; } - void Check(Context const* ctx, std::shared_ptr p_fmat, bst_bin_t n_bins, - HostDeviceVector const& hessian, std::vector const& w) const { + void CheckReg(Context const* ctx, std::shared_ptr p_fmat, bst_bin_t n_bins, + HostDeviceVector const& hessian, std::vector const& w, + std::size_t n_elements) const { auto const& h_hess = hessian.ConstHostVector(); { auto& h_weight = p_fmat->Info().weights_.HostVector(); @@ -807,7 +808,7 @@ class DeviceSketchWithHessianTest } HistogramCuts cuts_hess = - DeviceSketchWithHessian(ctx, p_fmat.get(), n_bins, hessian.ConstDeviceSpan()); + DeviceSketchWithHessian(ctx, p_fmat.get(), n_bins, hessian.ConstDeviceSpan(), n_elements); ValidateCuts(cuts_hess, p_fmat.get(), n_bins); // merge hessian @@ -819,7 +820,7 @@ class DeviceSketchWithHessianTest } } - HistogramCuts cuts_wh = DeviceSketch(ctx, p_fmat.get(), n_bins); + HistogramCuts cuts_wh = DeviceSketch(ctx, p_fmat.get(), n_bins, n_elements); ValidateCuts(cuts_wh, p_fmat.get(), n_bins); ASSERT_EQ(cuts_hess.Values().size(), cuts_wh.Values().size()); for (std::size_t i = 0; i < cuts_hess.Values().size(); ++i) { @@ -832,7 +833,8 @@ class DeviceSketchWithHessianTest protected: Context ctx_ = MakeCUDACtx(0); - void TestLTR(Context const* ctx, bst_row_t n_samples, bst_bin_t n_bins) const { + void TestLTR(Context const* ctx, bst_row_t n_samples, bst_bin_t n_bins, + std::size_t n_elements) const { auto x = GenerateRandom(n_samples, n_features_); std::vector gptr; @@ -850,7 +852,7 @@ class DeviceSketchWithHessianTest std::vector w(n_groups_, 1.0f); p_fmat->Info().weights_.HostVector() = w; HistogramCuts cuts_hess = - DeviceSketchWithHessian(ctx, p_fmat.get(), n_bins, hessian.ConstDeviceSpan()); + DeviceSketchWithHessian(ctx, p_fmat.get(), n_bins, hessian.ConstDeviceSpan(), n_elements); // make validation easier by converting it into sample weight. p_fmat->Info().weights_.HostVector() = h_hess; p_fmat->Info().group_ptr_.clear(); @@ -862,7 +864,8 @@ class DeviceSketchWithHessianTest // test with random group weight w = GenerateRandomWeights(n_groups_); p_fmat->Info().weights_.HostVector() = w; - cuts_hess = DeviceSketchWithHessian(ctx, p_fmat.get(), n_bins, hessian.ConstDeviceSpan()); + cuts_hess = + DeviceSketchWithHessian(ctx, p_fmat.get(), n_bins, hessian.ConstDeviceSpan(), n_elements); // make validation easier by converting it into sample weight. p_fmat->Info().weights_.HostVector() = h_hess; p_fmat->Info().group_ptr_.clear(); @@ -875,7 +878,7 @@ class DeviceSketchWithHessianTest auto gidx = dh::SegmentId(Span{gptr.data(), gptr.size()}, i); p_fmat->Info().weights_.HostVector()[i] = w[gidx] * h_hess[i]; } - auto cuts = DeviceSketch(ctx, p_fmat.get(), n_bins); + auto cuts = DeviceSketch(ctx, p_fmat.get(), n_bins, n_elements); ValidateCuts(cuts, p_fmat.get(), n_bins); ASSERT_EQ(cuts.Values().size(), cuts_hess.Values().size()); for (std::size_t i = 0; i < cuts.Values().size(); ++i) { @@ -883,14 +886,15 @@ class DeviceSketchWithHessianTest } } - void TestRegression(Context const* ctx, bst_row_t n_samples, bst_bin_t n_bins) const { + void TestRegression(Context const* ctx, bst_row_t n_samples, bst_bin_t n_bins, + std::size_t n_elements) const { auto x = GenerateRandom(n_samples, n_features_); auto p_fmat = GetDMatrixFromData(x, n_samples, n_features_); std::vector w = GenerateRandomWeights(n_samples); auto hessian = this->GenerateHessian(ctx, n_samples); - this->Check(ctx, p_fmat, n_bins, hessian, w); + this->CheckReg(ctx, p_fmat, n_bins, hessian, w, n_elements); } }; @@ -913,9 +917,11 @@ TEST_P(DeviceSketchWithHessianTest, DeviceSketchWithHessian) { auto n_samples = std::get<1>(param); auto n_bins = std::get<2>(param); if (std::get<0>(param)) { - this->TestLTR(&ctx_, n_samples, n_bins); + this->TestLTR(&ctx_, n_samples, n_bins, 0); + this->TestLTR(&ctx_, n_samples, n_bins, 512); } else { - this->TestRegression(&ctx_, n_samples, n_bins); + this->TestRegression(&ctx_, n_samples, n_bins, 0); + this->TestRegression(&ctx_, n_samples, n_bins, 512); } }