diff --git a/src/collective/nccl_device_communicator.cu b/src/collective/nccl_device_communicator.cu index 6599d4b5a30e..631193db4d86 100644 --- a/src/collective/nccl_device_communicator.cu +++ b/src/collective/nccl_device_communicator.cu @@ -70,7 +70,7 @@ NcclDeviceCommunicator::~NcclDeviceCommunicator() { namespace { ncclDataType_t GetNcclDataType(DataType const &data_type) { - ncclDataType_t result; + ncclDataType_t result{ncclInt8}; switch (data_type) { case DataType::kInt8: result = ncclInt8; @@ -108,7 +108,7 @@ bool IsBitwiseOp(Operation const &op) { } ncclRedOp_t GetNcclRedOp(Operation const &op) { - ncclRedOp_t result; + ncclRedOp_t result{ncclMax}; switch (op) { case Operation::kMax: result = ncclMax; diff --git a/src/common/bitfield.h b/src/common/bitfield.h index 6bb5f3404ba7..6cdf4412eae3 100644 --- a/src/common/bitfield.h +++ b/src/common/bitfield.h @@ -1,5 +1,5 @@ -/*! - * Copyright 2019 by Contributors +/** + * Copyright 2019-2023, XGBoost Contributors * \file bitfield.h */ #ifndef XGBOOST_COMMON_BITFIELD_H_ @@ -50,14 +50,17 @@ __forceinline__ __device__ BitFieldAtomicType AtomicAnd(BitFieldAtomicType* addr } #endif // defined(__CUDACC__) -/*! - * \brief A non-owning type with auxiliary methods defined for manipulating bits. +/** + * @brief A non-owning type with auxiliary methods defined for manipulating bits. * - * \tparam Direction Whether the bits start from left or from right. + * @tparam VT Underlying value type, must be an unsigned integer. + * @tparam Direction Whether the bits start from left or from right. + * @tparam IsConst Whether the view is const. */ template struct BitFieldContainer { using value_type = std::conditional_t; // NOLINT + using size_type = size_t; // NOLINT using index_type = size_t; // NOLINT using pointer = value_type*; // NOLINT @@ -70,8 +73,9 @@ struct BitFieldContainer { }; private: - common::Span bits_; - static_assert(!std::is_signed::value, "Must use unsiged type as underlying storage."); + value_type* bits_{nullptr}; + size_type n_values_{0}; + static_assert(!std::is_signed::value, "Must use an unsigned type as the underlying storage."); public: XGBOOST_DEVICE static Pos ToBitPos(index_type pos) { @@ -86,13 +90,15 @@ struct BitFieldContainer { public: BitFieldContainer() = default; - XGBOOST_DEVICE explicit BitFieldContainer(common::Span bits) : bits_{bits} {} - XGBOOST_DEVICE BitFieldContainer(BitFieldContainer const& other) : bits_{other.bits_} {} + XGBOOST_DEVICE explicit BitFieldContainer(common::Span bits) + : bits_{bits.data()}, n_values_{bits.size()} {} + BitFieldContainer(BitFieldContainer const& other) = default; + BitFieldContainer(BitFieldContainer&& other) = default; BitFieldContainer &operator=(BitFieldContainer const &that) = default; BitFieldContainer &operator=(BitFieldContainer &&that) = default; - XGBOOST_DEVICE common::Span Bits() { return bits_; } - XGBOOST_DEVICE common::Span Bits() const { return bits_; } + XGBOOST_DEVICE auto Bits() { return common::Span{bits_, NumValues()}; } + XGBOOST_DEVICE auto Bits() const { return common::Span{bits_, NumValues()}; } /*\brief Compute the size of needed memory allocation. The returned value is in terms * of number of elements with `BitFieldContainer::value_type'.
@@ -103,17 +109,17 @@ struct BitFieldContainer { #if defined(__CUDA_ARCH__) __device__ BitFieldContainer& operator|=(BitFieldContainer const& rhs) { auto tid = blockIdx.x * blockDim.x + threadIdx.x; - size_t min_size = min(bits_.size(), rhs.bits_.size()); + size_t min_size = min(NumValues(), rhs.NumValues()); if (tid < min_size) { - bits_[tid] |= rhs.bits_[tid]; + Data()[tid] |= rhs.Data()[tid]; } return *this; } #else BitFieldContainer& operator|=(BitFieldContainer const& rhs) { - size_t min_size = std::min(bits_.size(), rhs.bits_.size()); + size_t min_size = std::min(NumValues(), rhs.NumValues()); for (size_t i = 0; i < min_size; ++i) { - bits_[i] |= rhs.bits_[i]; + Data()[i] |= rhs.Data()[i]; } return *this; } @@ -121,75 +127,85 @@ struct BitFieldContainer { #if defined(__CUDA_ARCH__) __device__ BitFieldContainer& operator&=(BitFieldContainer const& rhs) { - size_t min_size = min(bits_.size(), rhs.bits_.size()); + size_t min_size = min(NumValues(), rhs.NumValues()); auto tid = blockIdx.x * blockDim.x + threadIdx.x; if (tid < min_size) { - bits_[tid] &= rhs.bits_[tid]; + Data()[tid] &= rhs.Data()[tid]; } return *this; } #else BitFieldContainer& operator&=(BitFieldContainer const& rhs) { - size_t min_size = std::min(bits_.size(), rhs.bits_.size()); + size_t min_size = std::min(NumValues(), rhs.NumValues()); for (size_t i = 0; i < min_size; ++i) { - bits_[i] &= rhs.bits_[i]; + Data()[i] &= rhs.Data()[i]; } return *this; } #endif // defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__) - __device__ auto Set(index_type pos) { + __device__ auto Set(index_type pos) noexcept(true) { Pos pos_v = Direction::Shift(ToBitPos(pos)); - value_type& value = bits_[pos_v.int_pos]; + value_type& value = Data()[pos_v.int_pos]; value_type set_bit = kOne << pos_v.bit_pos; using Type = typename dh::detail::AtomicDispatcher::Type; atomicOr(reinterpret_cast(&value), set_bit); } - __device__ void Clear(index_type pos) { + __device__ void Clear(index_type pos) noexcept(true) { Pos pos_v = Direction::Shift(ToBitPos(pos)); - value_type& value = bits_[pos_v.int_pos]; + value_type& value = Data()[pos_v.int_pos]; value_type clear_bit = ~(kOne << pos_v.bit_pos); using Type = typename dh::detail::AtomicDispatcher::Type; atomicAnd(reinterpret_cast(&value), clear_bit); } #else - void Set(index_type pos) { + void Set(index_type pos) noexcept(true) { Pos pos_v = Direction::Shift(ToBitPos(pos)); - value_type& value = bits_[pos_v.int_pos]; + value_type& value = Data()[pos_v.int_pos]; value_type set_bit = kOne << pos_v.bit_pos; value |= set_bit; } - void Clear(index_type pos) { + void Clear(index_type pos) noexcept(true) { Pos pos_v = Direction::Shift(ToBitPos(pos)); - value_type& value = bits_[pos_v.int_pos]; + value_type& value = Data()[pos_v.int_pos]; value_type clear_bit = ~(kOne << pos_v.bit_pos); value &= clear_bit; } #endif // defined(__CUDA_ARCH__) - XGBOOST_DEVICE bool Check(Pos pos_v) const { + XGBOOST_DEVICE bool Check(Pos pos_v) const noexcept(true) { pos_v = Direction::Shift(pos_v); - SPAN_LT(pos_v.int_pos, bits_.size()); - value_type const value = bits_[pos_v.int_pos]; + assert(pos_v.int_pos < NumValues()); + value_type const value = Data()[pos_v.int_pos]; value_type const test_bit = kOne << pos_v.bit_pos; value_type result = test_bit & value; return static_cast(result); } - XGBOOST_DEVICE bool Check(index_type pos) const { + [[nodiscard]] XGBOOST_DEVICE bool Check(index_type pos) const noexcept(true) { Pos pos_v = ToBitPos(pos); return Check(pos_v); } + /** + * @brief Returns the total number of bits that can be 
viewed. This is equal to or + larger than the actual number of valid bits. + */ + [[nodiscard]] XGBOOST_DEVICE size_type Capacity() const noexcept(true) { + return kValueSize * NumValues(); + } + /** + * @brief Number of storage units used in this bit field. + */ + [[nodiscard]] XGBOOST_DEVICE size_type NumValues() const noexcept(true) { return n_values_; } - XGBOOST_DEVICE size_t Size() const { return kValueSize * bits_.size(); } - - XGBOOST_DEVICE pointer Data() const { return bits_.data(); } + XGBOOST_DEVICE pointer Data() const noexcept(true) { return bits_; } - inline friend std::ostream & - operator<<(std::ostream &os, BitFieldContainer field) { - os << "Bits " << "storage size: " << field.bits_.size() << "\n"; - for (typename common::Span::index_type i = 0; i < field.bits_.size(); ++i) { - std::bitset::kValueSize> bset(field.bits_[i]); + inline friend std::ostream& operator<<(std::ostream& os, + BitFieldContainer field) { + os << "Bits " + << "storage size: " << field.NumValues() << "\n"; + for (typename common::Span::index_type i = 0; i < field.NumValues(); ++i) { + std::bitset::kValueSize> bset(field.Data()[i]); os << bset << "\n"; } return os; diff --git a/src/common/categorical.h b/src/common/categorical.h index d7e26281278f..249a818e5a42 100644 --- a/src/common/categorical.h +++ b/src/common/categorical.h @@ -1,5 +1,5 @@ -/*! - * Copyright 2020-2022 by XGBoost Contributors +/** + * Copyright 2020-2023, XGBoost Contributors * \file categorical.h */ #ifndef XGBOOST_COMMON_CATEGORICAL_H_ @@ -10,7 +10,6 @@ #include "bitfield.h" #include "xgboost/base.h" #include "xgboost/data.h" -#include "xgboost/parameter.h" #include "xgboost/span.h" namespace xgboost { diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 6380952d7f61..d2edf2ec8e5f 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -84,7 +84,7 @@ class HistogramCuts { return *this; } - uint32_t FeatureBins(bst_feature_t feature) const { + [[nodiscard]] bst_bin_t FeatureBins(bst_feature_t feature) const { return cut_ptrs_.ConstHostVector().at(feature + 1) - cut_ptrs_.ConstHostVector()[feature]; } @@ -92,8 +92,8 @@ class HistogramCuts { std::vector const& Values() const { return cut_values_.ConstHostVector(); } std::vector const& MinValues() const { return min_vals_.ConstHostVector(); } - bool HasCategorical() const { return has_categorical_; } - float MaxCategory() const { return max_cat_; } + [[nodiscard]] bool HasCategorical() const { return has_categorical_; } + [[nodiscard]] float MaxCategory() const { return max_cat_; } /** * \brief Set meta info about categorical features.
* @@ -105,12 +105,13 @@ class HistogramCuts { max_cat_ = max_cat; } - size_t TotalBins() const { return cut_ptrs_.ConstHostVector().back(); } + [[nodiscard]] bst_bin_t TotalBins() const { return cut_ptrs_.ConstHostVector().back(); } // Return the index of a cut point that is strictly greater than the input // value, or the last available index if none exists - bst_bin_t SearchBin(float value, bst_feature_t column_id, std::vector const& ptrs, - std::vector const& values) const { + [[nodiscard]] bst_bin_t SearchBin(float value, bst_feature_t column_id, + std::vector const& ptrs, + std::vector const& values) const { auto end = ptrs[column_id + 1]; auto beg = ptrs[column_id]; auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value); @@ -119,20 +120,20 @@ class HistogramCuts { return idx; } - bst_bin_t SearchBin(float value, bst_feature_t column_id) const { + [[nodiscard]] bst_bin_t SearchBin(float value, bst_feature_t column_id) const { return this->SearchBin(value, column_id, Ptrs(), Values()); } - /** * \brief Search the bin index for numerical feature. */ - bst_bin_t SearchBin(Entry const& e) const { return SearchBin(e.fvalue, e.index); } + [[nodiscard]] bst_bin_t SearchBin(Entry const& e) const { return SearchBin(e.fvalue, e.index); } /** * \brief Search the bin index for categorical feature. */ - bst_bin_t SearchCatBin(float value, bst_feature_t fidx, std::vector const& ptrs, - std::vector const& vals) const { + [[nodiscard]] bst_bin_t SearchCatBin(float value, bst_feature_t fidx, + std::vector const& ptrs, + std::vector const& vals) const { auto end = ptrs.at(fidx + 1) + vals.cbegin(); auto beg = ptrs[fidx] + vals.cbegin(); // Truncates the value in case it's not perfectly rounded. @@ -143,12 +144,14 @@ class HistogramCuts { } return bin_idx; } - bst_bin_t SearchCatBin(float value, bst_feature_t fidx) const { + [[nodiscard]] bst_bin_t SearchCatBin(float value, bst_feature_t fidx) const { auto const& ptrs = this->Ptrs(); auto const& vals = this->Values(); return this->SearchCatBin(value, fidx, ptrs, vals); } - bst_bin_t SearchCatBin(Entry const& e) const { return SearchCatBin(e.fvalue, e.index); } + [[nodiscard]] bst_bin_t SearchCatBin(Entry const& e) const { + return SearchCatBin(e.fvalue, e.index); + } /** * \brief Return numerical bin value given bin index. diff --git a/src/data/array_interface.h b/src/data/array_interface.h index 1b18f140aa67..bd66c2a53e70 100644 --- a/src/data/array_interface.h +++ b/src/data/array_interface.h @@ -590,7 +590,7 @@ class ArrayInterface { template void DispatchDType(ArrayInterface const array, std::int32_t device, Fn fn) { // Only used for cuDF at the moment. 
- CHECK_EQ(array.valid.Size(), 0); + CHECK_EQ(array.valid.Capacity(), 0); auto dispatch = [&](auto t) { using T = std::remove_const_t const; // Set the data size to max as we don't know the original size of a sliced array: diff --git a/src/data/data.cc b/src/data/data.cc index 00cff8ab0929..d305749eefb5 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -416,7 +416,8 @@ void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::TensorReshape(array.shape); return; } - CHECK(array.valid.Size() == 0) << "Meta info like label or weight can not have missing value."; + CHECK_EQ(array.valid.Capacity(), 0) + << "Meta info like label or weight can not have missing value."; if (array.is_contiguous && array.type == ToDType::kType) { // Handle contigious p_out->ModifyInplace([&](HostDeviceVector* data, common::Span shape) { diff --git a/src/data/data.cu b/src/data/data.cu index eccbe7567193..0f1fda661500 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -33,7 +33,8 @@ void CopyTensorInfoImpl(CUDAContext const* ctx, Json arr_interface, linalg::Tens p_out->Reshape(array.shape); return; } - CHECK(array.valid.Size() == 0) << "Meta info like label or weight can not have missing value."; + CHECK_EQ(array.valid.Capacity(), 0) + << "Meta info like label or weight can not have missing value."; auto ptr_device = SetDeviceToPtr(array.data); p_out->SetDevice(ptr_device); diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index aa218fa31199..13fcf9adf8d6 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -5,6 +5,7 @@ #include #include "../common/categorical.h" +#include "../common/cuda_context.cuh" #include "../common/hist_util.cuh" #include "../common/random.h" #include "../common/transform_iterator.h" // MakeIndexTransformIter @@ -313,7 +314,8 @@ void CopyGHistToEllpack(GHistIndexMatrix const& page, common::Span auto d_csc_indptr = dh::ToSpan(csc_indptr); auto bin_type = page.index.GetBinTypeSize(); - common::CompressedBufferWriter writer{page.cut.TotalBins() + 1}; // +1 for null value + common::CompressedBufferWriter writer{page.cut.TotalBins() + + static_cast(1)}; // +1 for null value dh::LaunchN(row_stride * page.Size(), [=] __device__(size_t idx) mutable { auto ridx = idx / row_stride; @@ -357,8 +359,10 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& pag // copy gidx common::CompressedByteT* d_compressed_buffer = gidx_buffer.DevicePointer(); - dh::device_vector row_ptr(page.row_ptr); + dh::device_vector row_ptr(page.row_ptr.size()); auto d_row_ptr = dh::ToSpan(row_ptr); + dh::safe_cuda(cudaMemcpyAsync(d_row_ptr.data(), page.row_ptr.data(), d_row_ptr.size_bytes(), + cudaMemcpyHostToDevice, ctx->CUDACtx()->Stream())); auto accessor = this->GetDeviceAccessor(ctx->gpu_id, ft); auto null = accessor.NullValue(); diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h index 02aa9a5c0e9e..d4324000f025 100644 --- a/src/data/sparse_page_dmatrix.h +++ b/src/data/sparse_page_dmatrix.h @@ -7,9 +7,6 @@ #ifndef XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_ #define XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_ -#include -#include - #include #include #include @@ -20,35 +17,33 @@ #include "ellpack_page_source.h" #include "gradient_index_page_source.h" #include "sparse_page_source.h" +#include "xgboost/data.h" +#include "xgboost/logging.h" -namespace xgboost { -namespace data { +namespace xgboost::data { /** * \brief DMatrix used for external memory. 
* * The external memory is created for controlling memory usage by splitting up data into - multiple batches. However that doesn't mean we will actually process exact 1 batch at - a time, which would be terribly slow considering that we have to loop through the - whole dataset for every tree split. So we use async pre-fetch and let caller to decide - how many batches it wants to process by returning data as shared pointer. The caller - can use async function to process the data or just stage those batches, making the - decision is out of the scope for sparse page dmatrix. These 2 optimizations might - defeat the purpose of splitting up dataset since if you load all the batches then the - memory usage is even worse than using a single batch. Essentially we need to control - how many batches can be in memory at the same time. + multiple batches. However, that doesn't mean we will actually process exactly 1 batch + at a time, which would be terribly slow considering that we have to loop through the + whole dataset for every tree split. So we use async to pre-fetch pages and let the + caller decide how many batches it wants to process by returning data as a shared + pointer. The caller can use an async function to process the data or just stage those + batches based on its use cases. These two optimizations might defeat the purpose of + splitting up the dataset since if you stage all the batches then the memory usage might be + even worse than using a single batch. As a result, we must control how many batches can + be in memory at any given time. * - * Right now the write to the cache is sequential operation and is blocking, reading from - cache is async but with a hard coded limit of 4 pages as an heuristic. So by sparse - dmatrix itself there can be only 9 pages in main memory (might be of different types) - at the same time: 1 page pending for write, 4 pre-fetched sparse pages, 4 pre-fetched - dependent pages. If the caller stops iteration at the middle and start again, then the - number of pages in memory can hit 16 due to pre-fetching, but this should be a bug in - caller's code (XGBoost doesn't discard a large portion of data at the end, there's not - sampling algo that samples only the first portion of data). + * Right now the write to the cache is a sequential operation and is blocking. Reading + from cache, on the other hand, is async but with a hard coded limit of 3 pages as a + heuristic. So by sparse dmatrix itself there can be only 7 pages in main memory (might + be of different types) at the same time: 1 page pending for write, 3 pre-fetched sparse + pages, 3 pre-fetched dependent pages. * * Of course if the caller decides to retain some batches to perform parallel processing, * then we might load all pages in memory, which is also considered as a bug in caller's - * code. So if the algo supports external memory, it must be careful that queue for async + * code. So if the algo supports external memory, it must be careful that the queue for async + call must have an upper limit. * * Another assumption we make is that the data must be immutable so caller should never @@ -101,7 +96,7 @@ class SparsePageDMatrix : public DMatrix { MetaInfo &Info() override; const MetaInfo &Info() const override; Context const *Ctx() const override { return &fmat_ctx_; } - + // The only DMatrix implementation that returns false.
bool SingleColBlock() const override { return false; } DMatrix *Slice(common::Span) override { LOG(FATAL) << "Slicing DMatrix is not supported for external memory."; @@ -153,6 +148,5 @@ inline std::string MakeCache(SparsePageDMatrix *ptr, std::string format, std::st } return id; } -} // namespace data -} // namespace xgboost +} // namespace xgboost::data #endif // XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_ diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index b4e42f2db421..9f7bee521b09 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -6,39 +6,43 @@ #define XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_ #include // for min -#include // async +#include // for async #include #include #include #include -#include +#include // for pair, move #include #include "../common/common.h" -#include "../common/io.h" // for PrivateMmapStream, PadPageForMMAP +#include "../common/io.h" // for PrivateMmapConstStream #include "../common/timer.h" // for Monitor, Timer #include "adapter.h" -#include "dmlc/common.h" // OMPException -#include "proxy_dmatrix.h" -#include "sparse_page_writer.h" +#include "dmlc/common.h" // for OMPException +#include "proxy_dmatrix.h" // for DMatrixProxy +#include "sparse_page_writer.h" // for SparsePageFormat #include "xgboost/base.h" #include "xgboost/data.h" namespace xgboost::data { inline void TryDeleteCacheFile(const std::string& file) { if (std::remove(file.c_str()) != 0) { + // Don't throw, this is called in a destructor. LOG(WARNING) << "Couldn't remove external memory cache file " << file << "; you may want to remove it manually"; } } +/** + * @brief Information about the cache including path and page offsets. + */ struct Cache { // whether the write to the cache is complete bool written; std::string name; std::string format; // offset into binary cache file. - std::vector offset; + std::vector offset; Cache(bool w, std::string n, std::string fmt) : written{w}, name{std::move(n)}, format{std::move(fmt)} { @@ -50,14 +54,24 @@ struct Cache { return name + format; } - std::string ShardName() { + [[nodiscard]] std::string ShardName() const { return ShardName(this->name, this->format); } - void Push(std::size_t n_bytes) { - offset.push_back(n_bytes); + /** + * @brief Record a page with size of n_bytes. + */ + void Push(std::size_t n_bytes) { offset.push_back(n_bytes); } + /** + * @brief Returns the view start and length for the i^th page. + */ + [[nodiscard]] auto View(std::size_t i) const { + std::uint64_t off = offset.at(i); + std::uint64_t len = offset.at(i + 1) - offset[i]; + return std::pair{off, len}; } - - // The write is completed. + /** + * @brief Call this once the write for the cache is complete. + */ void Commit() { if (!written) { std::partial_sum(offset.begin(), offset.end(), offset.begin()); @@ -66,7 +80,7 @@ struct Cache { } }; -// Prevents multi-threaded call. +// Prevents multi-threaded call to `GetBatches`. class TryLockGuard { std::mutex& lock_; @@ -79,22 +93,25 @@ class TryLockGuard { } }; +/** + * @brief Base class for all page sources. Handles fetching, writing, and iteration. + */ template class SparsePageSourceImpl : public BatchIteratorImpl { protected: // Prevents calling this iterator from multiple places(or threads). std::mutex single_threaded_; - + // The current page. std::shared_ptr page_; bool at_end_ {false}; float missing_; - int nthreads_; + std::int32_t nthreads_; bst_feature_t n_features_; - - uint32_t count_{0}; - - uint32_t n_batches_ {0}; + // Index to the current page. 
+ std::uint32_t count_{0}; + // Total number of batches. + std::uint32_t n_batches_{0}; std::shared_ptr cache_info_; @@ -102,6 +119,9 @@ class SparsePageSourceImpl : public BatchIteratorImpl { // A ring storing futures to data. Since the DMatrix iterator is forward only, so we // can pre-fetch data in a ring. std::unique_ptr ring_{new Ring}; + // Catching exceptions in pre-fetch threads to prevent segfaults. This doesn't always work though, + // as OOM errors can be delayed due to lazy commit. On the bright side, if mmap is used then + // OOM errors should be rare. dmlc::OMPException exec_; common::Monitor monitor_; @@ -123,7 +143,6 @@ class SparsePageSourceImpl : public BatchIteratorImpl { exec_.Rethrow(); - monitor_.Start("launch"); for (std::size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) { fetch_it %= n_batches_; // ring if (ring_->at(fetch_it).valid()) { @@ -134,33 +153,25 @@ class SparsePageSourceImpl : public BatchIteratorImpl { ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, this]() { auto page = std::make_shared(); this->exec_.Run([&] { - common::Timer timer; - timer.Start(); std::unique_ptr> fmt{CreatePageFormat("raw")}; - auto n = self->cache_info_->ShardName(); - - std::uint64_t offset = self->cache_info_->offset.at(fetch_it); - std::uint64_t length = self->cache_info_->offset.at(fetch_it + 1) - offset; - - auto fi = std::make_unique(n, offset, length); + auto name = self->cache_info_->ShardName(); + auto [offset, length] = self->cache_info_->View(fetch_it); + auto fi = std::make_unique(name, offset, length); CHECK(fmt->Read(page.get(), fi.get())); - timer.Stop(); - - LOG(INFO) << "Read a page `" << typeid(S).name() << "` in " << timer.ElapsedSeconds() - << " seconds."; }); return page; }); } - monitor_.Stop("launch"); CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }), n_prefetch_batches) << "Sparse DMatrix assumes forward iteration."; + monitor_.Start("Wait"); page_ = (*ring_)[count_].get(); - monitor_.Stop("Wait"); CHECK(!(*ring_)[count_].valid()); + monitor_.Stop("Wait"); + exec_.Rethrow(); return true; @@ -183,6 +194,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl { auto bytes = fmt->Write(*page_, fo.get()); timer.Stop(); + // Not entirely accurate, as the kernel doesn't have to flush the data. LOG(INFO) << static_cast(bytes) / 1024.0 / 1024.0 << " MB written in " << timer.ElapsedSeconds() << " seconds."; cache_info_->Push(bytes); @@ -204,6 +216,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl { SparsePageSourceImpl(SparsePageSourceImpl const &that) = delete; ~SparsePageSourceImpl() override { + // Don't orphan the threads. for (auto& fu : *ring_) { if (fu.valid()) { fu.get(); @@ -211,18 +224,18 @@ class SparsePageSourceImpl : public BatchIteratorImpl { } } - uint32_t Iter() const { return count_; } + [[nodiscard]] uint32_t Iter() const { return count_; } const S &operator*() const override { CHECK(page_); return *page_; } - std::shared_ptr Page() const override { + [[nodiscard]] std::shared_ptr Page() const override { return page_; } - bool AtEnd() const override { + [[nodiscard]] bool AtEnd() const override { return at_end_; } @@ -230,20 +243,23 @@ class SparsePageSourceImpl : public BatchIteratorImpl { TryLockGuard guard{single_threaded_}; at_end_ = false; count_ = 0; + // Pre-fetch for the next round of iterations. this->Fetch(); } }; #if defined(XGBOOST_USE_CUDA) +// Push data from CUDA.
void DevicePush(DMatrixProxy* proxy, float missing, SparsePage* page); #else inline void DevicePush(DMatrixProxy*, float, SparsePage*) { common::AssertGPUSupport(); } #endif class SparsePageSource : public SparsePageSourceImpl { + // This is the source from the user. DataIterProxy iter_; DMatrixProxy* proxy_; - size_t base_row_id_ {0}; + std::size_t base_row_id_{0}; void Fetch() final { page_ = std::make_shared(); diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 4b834e78fb90..98e38068239f 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -439,7 +439,7 @@ struct ShapSplitCondition { if (isnan(x)) { return is_missing_branch; } - if (categories.Size() != 0) { + if (categories.Capacity() != 0) { auto cat = static_cast(x); return categories.Check(cat); } else { @@ -454,7 +454,7 @@ struct ShapSplitCondition { if (l.Data() == r.Data()) { return l; } - if (l.Size() > r.Size()) { + if (l.Capacity() > r.Capacity()) { thrust::swap(l, r); } for (size_t i = 0; i < r.Bits().size(); ++i) { @@ -466,7 +466,7 @@ struct ShapSplitCondition { // Combine two split conditions on the same feature XGBOOST_DEVICE void Merge(ShapSplitCondition other) { // Combine duplicate features - if (categories.Size() != 0 || other.categories.Size() != 0) { + if (categories.Capacity() != 0 || other.categories.Capacity() != 0) { categories = Intersect(categories, other.categories); } else { feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound); diff --git a/src/tree/constraints.cu b/src/tree/constraints.cu index b6db0eda0739..ae1d3073c7cc 100644 --- a/src/tree/constraints.cu +++ b/src/tree/constraints.cu @@ -1,5 +1,5 @@ -/*! - * Copyright 2019 XGBoost contributors +/** + * Copyright 2019-2023, XGBoost contributors */ #include #include @@ -140,20 +140,20 @@ void FeatureInteractionConstraintDevice::Reset() { __global__ void ClearBuffersKernel( LBitField64 result_buffer_output, LBitField64 result_buffer_input) { auto tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < result_buffer_output.Size()) { + if (tid < result_buffer_output.Capacity()) { result_buffer_output.Clear(tid); } - if (tid < result_buffer_input.Size()) { + if (tid < result_buffer_input.Capacity()) { result_buffer_input.Clear(tid); } } void FeatureInteractionConstraintDevice::ClearBuffers() { - CHECK_EQ(output_buffer_bits_.Size(), input_buffer_bits_.Size()); - CHECK_LE(feature_buffer_.Size(), output_buffer_bits_.Size()); + CHECK_EQ(output_buffer_bits_.Capacity(), input_buffer_bits_.Capacity()); + CHECK_LE(feature_buffer_.Capacity(), output_buffer_bits_.Capacity()); uint32_t constexpr kBlockThreads = 256; auto const n_grids = static_cast( - common::DivRoundUp(input_buffer_bits_.Size(), kBlockThreads)); + common::DivRoundUp(input_buffer_bits_.Capacity(), kBlockThreads)); dh::LaunchKernel {n_grids, kBlockThreads} ( ClearBuffersKernel, output_buffer_bits_, input_buffer_bits_); @@ -207,11 +207,11 @@ common::Span FeatureInteractionConstraintDevice::Query( ClearBuffers(); LBitField64 node_constraints = s_node_constraints_[nid]; - CHECK_EQ(input_buffer_bits_.Size(), output_buffer_bits_.Size()); + CHECK_EQ(input_buffer_bits_.Capacity(), output_buffer_bits_.Capacity()); uint32_t constexpr kBlockThreads = 256; auto n_grids = static_cast( - common::DivRoundUp(output_buffer_bits_.Size(), kBlockThreads)); + common::DivRoundUp(output_buffer_bits_.Capacity(), kBlockThreads)); dh::LaunchKernel {n_grids, kBlockThreads} ( SetInputBufferKernel, feature_list, input_buffer_bits_); @@ -274,13 
+274,13 @@ __global__ void InteractionConstraintSplitKernel(LBitField64 feature, LBitField64 left, LBitField64 right) { auto tid = threadIdx.x + blockDim.x * blockIdx.x; - if (tid > node.Size()) { + if (tid > node.Capacity()) { return; } // enable constraints from feature node |= feature; // clear the buffer after use - if (tid < feature.Size()) { + if (tid < feature.Capacity()) { feature.Clear(tid); } @@ -323,7 +323,7 @@ void FeatureInteractionConstraintDevice::Split( s_sets_, s_sets_ptr_); uint32_t constexpr kBlockThreads = 256; - auto n_grids = static_cast(common::DivRoundUp(node.Size(), kBlockThreads)); + auto n_grids = static_cast(common::DivRoundUp(node.Capacity(), kBlockThreads)); dh::LaunchKernel {n_grids, kBlockThreads} ( InteractionConstraintSplitKernel, diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 7550904b5753..f32ea701f3a5 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -213,7 +213,7 @@ std::vector GetSplitCategories(RegTree const &tree, int32_t nidx) { auto split = common::KCatBitField{csr.categories.subspan(seg.beg, seg.size)}; std::vector cats; - for (size_t i = 0; i < split.Size(); ++i) { + for (size_t i = 0; i < split.Capacity(); ++i) { if (split.Check(i)) { cats.push_back(static_cast(i)); } @@ -1004,7 +1004,7 @@ void RegTree::SaveCategoricalSplit(Json* p_out) const { auto segment = split_categories_segments_[i]; auto node_categories = this->GetSplitCategories().subspan(segment.beg, segment.size); common::KCatBitField const cat_bits(node_categories); - for (size_t i = 0; i < cat_bits.Size(); ++i) { + for (size_t i = 0; i < cat_bits.Capacity(); ++i) { if (cat_bits.Check(i)) { categories.GetArray().emplace_back(i); } diff --git a/tests/cpp/common/test_bitfield.cc b/tests/cpp/common/test_bitfield.cc index c7b2d5cb9cea..902e69f85ad8 100644 --- a/tests/cpp/common/test_bitfield.cc +++ b/tests/cpp/common/test_bitfield.cc @@ -1,5 +1,5 @@ -/*! - * Copyright 2019 XGBoost contributors +/** + * Copyright 2019-2023, XGBoost contributors */ #include #include "../../../src/common/bitfield.h" @@ -14,7 +14,7 @@ TEST(BitField, Check) { static_cast::index_type>( storage.size())}); size_t true_bit = 190; - for (size_t i = true_bit + 1; i < bits.Size(); ++i) { + for (size_t i = true_bit + 1; i < bits.Capacity(); ++i) { ASSERT_FALSE(bits.Check(i)); } ASSERT_TRUE(bits.Check(true_bit)); @@ -34,7 +34,7 @@ TEST(BitField, Check) { ASSERT_FALSE(bits.Check(i)); } ASSERT_TRUE(bits.Check(true_bit)); - for (size_t i = true_bit + 1; i < bits.Size(); ++i) { + for (size_t i = true_bit + 1; i < bits.Capacity(); ++i) { ASSERT_FALSE(bits.Check(i)); } } diff --git a/tests/cpp/common/test_bitfield.cu b/tests/cpp/common/test_bitfield.cu index 98fbd2ad10d2..a9b183c43740 100644 --- a/tests/cpp/common/test_bitfield.cu +++ b/tests/cpp/common/test_bitfield.cu @@ -1,5 +1,5 @@ -/*! 
- * Copyright 2019 XGBoost contributors +/** + * Copyright 2019-2023, XGBoost contributors */ #include #include @@ -12,7 +12,7 @@ namespace xgboost { __global__ void TestSetKernel(LBitField64 bits) { auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid < bits.Size()) { + if (tid < bits.Capacity()) { bits.Set(tid); } } @@ -36,20 +36,16 @@ TEST(BitField, GPUSet) { std::vector h_storage(storage.size()); thrust::copy(storage.begin(), storage.end(), h_storage.begin()); - - LBitField64 outputs { - common::Span{h_storage.data(), - h_storage.data() + h_storage.size()}}; + LBitField64 outputs{ + common::Span{h_storage.data(), h_storage.data() + h_storage.size()}}; for (size_t i = 0; i < kBits; ++i) { ASSERT_TRUE(outputs.Check(i)); } } -__global__ void TestOrKernel(LBitField64 lhs, LBitField64 rhs) { - lhs |= rhs; -} - -TEST(BitField, GPUAnd) { +namespace { +template +void TestGPULogic(Op op) { uint32_t constexpr kBits = 128; dh::device_vector lhs_storage(kBits); dh::device_vector rhs_storage(kBits); @@ -57,13 +53,32 @@ TEST(BitField, GPUAnd) { auto rhs = LBitField64(dh::ToSpan(rhs_storage)); thrust::fill(lhs_storage.begin(), lhs_storage.end(), 0UL); thrust::fill(rhs_storage.begin(), rhs_storage.end(), ~static_cast(0UL)); - TestOrKernel<<<1, kBits>>>(lhs, rhs); + dh::LaunchN(kBits, [=] __device__(auto) mutable { op(lhs, rhs); }); std::vector h_storage(lhs_storage.size()); thrust::copy(lhs_storage.begin(), lhs_storage.end(), h_storage.begin()); - LBitField64 outputs {{h_storage.data(), h_storage.data() + h_storage.size()}}; - for (size_t i = 0; i < kBits; ++i) { - ASSERT_TRUE(outputs.Check(i)); + LBitField64 outputs{{h_storage.data(), h_storage.data() + h_storage.size()}}; + if (is_and) { + for (size_t i = 0; i < kBits; ++i) { + ASSERT_FALSE(outputs.Check(i)); + } + } else { + for (size_t i = 0; i < kBits; ++i) { + ASSERT_TRUE(outputs.Check(i)); + } } } -} // namespace xgboost \ No newline at end of file + +void TestGPUAnd() { + TestGPULogic([] XGBOOST_DEVICE(LBitField64 & lhs, LBitField64 const& rhs) { lhs &= rhs; }); +} + +void TestGPUOr() { + TestGPULogic([] XGBOOST_DEVICE(LBitField64 & lhs, LBitField64 const& rhs) { lhs |= rhs; }); +} +} // namespace + +TEST(BitField, GPUAnd) { TestGPUAnd(); } + +TEST(BitField, GPUOr) { TestGPUOr(); } +} // namespace xgboost diff --git a/tests/cpp/common/test_column_matrix.cc b/tests/cpp/common/test_column_matrix.cc index 0578683d8f1d..8b8df48612bd 100644 --- a/tests/cpp/common/test_column_matrix.cc +++ b/tests/cpp/common/test_column_matrix.cc @@ -83,7 +83,9 @@ template void CheckColumWithMissingValue(const DenseColumnIter& col, const GHistIndexMatrix& gmat) { for (auto i = 0ull; i < col.Size(); i++) { - if (col.IsMissing(i)) continue; + if (col.IsMissing(i)) { + continue; + } EXPECT_EQ(gmat.index[gmat.row_ptr[i]], col.GetGlobalBinIdx(i)); } } diff --git a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu index f1317fc02511..cb2f7d604ea5 100644 --- a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu +++ b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu @@ -285,8 +285,6 @@ TEST(GpuHist, PartitionTwoNodes) { dh::ToSpan(feature_histogram_b)}; thrust::device_vector results(2); evaluator.EvaluateSplits({0, 1}, 1, dh::ToSpan(inputs), shared_inputs, dh::ToSpan(results)); - GPUExpandEntry result_a = results[0]; - GPUExpandEntry result_b = results[1]; EXPECT_EQ(std::bitset<32>(evaluator.GetHostNodeCats(0)[0]), std::bitset<32>("10000000000000000000000000000000")); 
EXPECT_EQ(std::bitset<32>(evaluator.GetHostNodeCats(1)[0]), diff --git a/tests/cpp/tree/test_constraints.cu b/tests/cpp/tree/test_constraints.cu index c9f1639b30c2..09e72a1d2bfa 100644 --- a/tests/cpp/tree/test_constraints.cu +++ b/tests/cpp/tree/test_constraints.cu @@ -1,5 +1,5 @@ -/*! - * Copyright 2019 XGBoost contributors +/** + * Copyright 2019-2023, XGBoost contributors */ #include #include @@ -53,7 +53,7 @@ void CompareBitField(LBitField64 d_field, std::set positions) { LBitField64 h_field{ {h_field_storage.data(), h_field_storage.data() + h_field_storage.size()} }; - for (size_t i = 0; i < h_field.Size(); ++i) { + for (size_t i = 0; i < h_field.Capacity(); ++i) { if (positions.find(i) != positions.cend()) { ASSERT_TRUE(h_field.Check(i)); } else { @@ -82,7 +82,7 @@ TEST(GPUFeatureInteractionConstraint, Init) { {h_node_storage.data(), h_node_storage.data() + h_node_storage.size()} }; // no feature is attached to node. - for (size_t i = 0; i < h_node.Size(); ++i) { + for (size_t i = 0; i < h_node.Capacity(); ++i) { ASSERT_FALSE(h_node.Check(i)); } }
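Note on the `Size()` to `Capacity()`/`NumValues()` rename in `bitfield.h`: the old `Size()` reported the number of addressable bits, which is always a multiple of the storage word size and can therefore exceed the number of logically valid bits. The sketch below is a minimal, host-only illustration of that distinction under the patch's semantics; `ToyBitField` and its members are hypothetical stand-ins, not the real `BitFieldContainer`.

```cpp
// NumValues() counts storage words, Capacity() counts addressable bits
// (kValueSize * NumValues()), which can exceed the number of valid bits.
#include <cassert>
#include <cstdint>
#include <vector>

struct ToyBitField {
  using value_type = std::uint64_t;
  static constexpr std::size_t kValueSize = sizeof(value_type) * 8;

  value_type* bits{nullptr};  // raw pointer + count, mirroring the Span -> pointer change
  std::size_t n_values{0};

  std::size_t NumValues() const { return n_values; }              // storage words
  std::size_t Capacity() const { return kValueSize * n_values; }  // viewable bits
  void Set(std::size_t pos) { bits[pos / kValueSize] |= value_type{1} << (pos % kValueSize); }
  bool Check(std::size_t pos) const {
    return (bits[pos / kValueSize] >> (pos % kValueSize)) & value_type{1};
  }
};

int main() {
  std::vector<std::uint64_t> storage(3, 0);  // 3 words -> 192 addressable bits
  ToyBitField field{storage.data(), storage.size()};
  assert(field.NumValues() == 3);
  assert(field.Capacity() == 192);  // >= the number of logically valid bits
  field.Set(190);
  assert(field.Check(190) && !field.Check(189));
  return 0;
}
```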
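Note on `Cache::Push()`/`Commit()`/`View()` in `sparse_page_source.h`: the new `View()` helper centralizes the offset arithmetic the prefetcher previously did inline. Below is a self-contained sketch of that bookkeeping, assuming the constructor seeds `offset` with a leading zero (the constructor body is not shown in this hunk); `ToyCache` only mirrors the arithmetic and is not the actual struct.

```cpp
// Push() records the byte size of each written page, Commit() converts the sizes into a
// prefix sum, and View(i) returns the (offset, length) pair handed to the mmap stream.
#include <cassert>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

struct ToyCache {
  std::vector<std::uint64_t> offset{0};  // assumed to start with a leading zero

  void Push(std::size_t n_bytes) { offset.push_back(n_bytes); }
  void Commit() { std::partial_sum(offset.begin(), offset.end(), offset.begin()); }
  std::pair<std::uint64_t, std::uint64_t> View(std::size_t i) const {
    return {offset.at(i), offset.at(i + 1) - offset[i]};
  }
};

int main() {
  ToyCache cache;
  cache.Push(100);  // page 0: 100 bytes
  cache.Push(250);  // page 1: 250 bytes
  cache.Commit();
  auto [off0, len0] = cache.View(0);
  auto [off1, len1] = cache.View(1);
  assert(off0 == 0 && len0 == 100);
  assert(off1 == 100 && len1 == 250);
  return 0;
}
```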
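Note on the `row_ptr` copy in `ellpack_page.cu`: the patch swaps the `device_vector` copy-construction for an explicit `cudaMemcpyAsync` on the context's stream. A minimal CUDA-runtime sketch of that pattern follows; the stream and buffer names are illustrative and not taken from XGBoost.

```cpp
// Allocate the device buffer up front, then issue an explicit async H2D copy on a
// chosen stream instead of the synchronous copy done when constructing a device_vector
// from a host vector.
#include <cuda_runtime.h>

#include <cstdio>
#include <vector>

int main() {
  std::vector<std::size_t> h_row_ptr{0, 3, 5, 9};  // stand-in for page.row_ptr

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  std::size_t* d_row_ptr = nullptr;
  std::size_t n_bytes = h_row_ptr.size() * sizeof(std::size_t);
  cudaMalloc(&d_row_ptr, n_bytes);
  // Copy ordered on `stream`; with pageable host memory the transfer may still be
  // staged synchronously, but it avoids implicit work on the default stream.
  cudaMemcpyAsync(d_row_ptr, h_row_ptr.data(), n_bytes, cudaMemcpyHostToDevice, stream);
  cudaStreamSynchronize(stream);  // the host vector must outlive the copy

  cudaFree(d_row_ptr);
  cudaStreamDestroy(stream);
  std::printf("copied %zu bytes\n", n_bytes);
  return 0;
}
```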