support multiple batches in gpu_hist #5014

Merged: 30 commits merged into master from gpu-hist-batches on Nov 16, 2019
Changes from 23 commits

Commits (30)
4e0824e
get rid of BinCount() method
rongou Oct 28, 2019
f5160cc
pass dmatrix to GPUHistMakerDevice
rongou Oct 29, 2019
397301e
reset row partitioner to n_rows on a page
rongou Oct 29, 2019
ca7f132
Revert "pass dmatrix to GPUHistMakerDevice"
rongou Oct 30, 2019
57121bf
get rid of the todo
rongou Oct 30, 2019
aae1e8f
remove redundant code
rongou Oct 30, 2019
48fabd6
add page size to BatchParam
rongou Oct 30, 2019
a299db1
test multiple ellpack pages
rongou Oct 30, 2019
4365625
add failing test for gpu_hist in external memory mode
rongou Oct 30, 2019
f9f0b7e
handle multiple batches in InitRoot
rongou Oct 31, 2019
20fef95
support multiple batches in gpu_hist
rongou Oct 31, 2019
0b48ab6
debugging failing test
rongou Nov 4, 2019
bb6afd8
add tests for ellpack page content
rongou Nov 5, 2019
3ead65d
add tests for ellpack content
rongou Nov 5, 2019
f0f8b54
test looping through ellpack pages multiple times
rongou Nov 5, 2019
68d08c5
tests passing
rongou Nov 5, 2019
c6b8e8a
fix clang tidy warning
rongou Nov 5, 2019
70e424e
make the ellpack tests more forgiving
rongou Nov 6, 2019
a6cca35
Merge branch 'master' into gpu-hist-batches
rongou Nov 6, 2019
8d8b426
move base_rowid into EllpackMatrix
rongou Nov 6, 2019
14734d9
change row partitioner back to absolute row ids
rongou Nov 7, 2019
2a3c02a
Merge branch 'master' into gpu-hist-batches
rongou Nov 11, 2019
e86edc1
actually verify every row
rongou Nov 11, 2019
5a75451
add a libsvm generator
rongou Nov 12, 2019
0662f8d
libsvm is 0-based
rongou Nov 13, 2019
a894c36
Merge branch 'master' into gpu-hist-batches
rongou Nov 13, 2019
73eec1a
Merge branch 'master' into gpu-hist-batches
trivialfis Nov 14, 2019
3b716ce
Fix merge conflict.
trivialfis Nov 14, 2019
377d043
fix a few issues
rongou Nov 15, 2019
a6bf0dd
minor formatting
rongou Nov 15, 2019
9 changes: 9 additions & 0 deletions include/xgboost/data.h
@@ -166,6 +166,15 @@ struct BatchParam {
int max_bin;
/*! \brief Number of rows in a GPU batch, used for finding quantiles on GPU. */
int gpu_batch_nrows;
/*! \brief Page size for external memory mode. */
size_t gpu_page_size;
Member:

Do we expose this to users?

Contributor Author:

Yes. See below.


inline bool operator!=(const BatchParam& other) const {
return gpu_id != other.gpu_id ||
max_bin != other.max_bin ||
gpu_batch_nrows != other.gpu_batch_nrows ||
gpu_page_size != other.gpu_page_size;
}
};

/*!
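For context, a minimal sketch of how a caller might thread the new field through when requesting ELLPACK batches; this call site and its variable names are assumed for illustration and are not part of the diff:

// Hypothetical call site: build a BatchParam carrying the page size from
// the global parameters, then iterate over the resulting ELLPACK pages.
BatchParam param;
param.gpu_id = generic_param.gpu_id;
param.max_bin = max_bin;
param.gpu_batch_nrows = gpu_batch_nrows;
param.gpu_page_size = generic_param.gpu_page_size;
for (auto& page : dmat->GetBatches<EllpackPage>(param)) {
  // ... build histograms from page.Impl()->matrix ...
}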
6 changes: 6 additions & 0 deletions include/xgboost/generic_parameters.h
@@ -21,6 +21,8 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
int nthread;
// primary device, -1 means no gpu.
int gpu_id;
// gpu page size in external memory mode, 0 means using the default.
size_t gpu_page_size;
Member:

A parameter is configurable by users, so please don't define it twice; make one of them a normal variable. If we don't want users to configure it, don't use a parameter at all. Note that pickling might lose some of this information, and dask uses pickle to move the booster around between workers.

Contributor Author:

It's only defined as a configurable parameter once, here; the other one is really just plumbing. For now this is mostly used for testing, but a user may want to set it depending on how much GPU memory they have.

Member:

Em... we want to do parameter validation, like detecting unused parameters, and this may add some extra difficulty. Do you think it's possible to make this a DMatrix parameter instead of a global one? Maybe in another PR? Sorry for nitpicking here.

Member:

Agreed, we need to be careful about adding global parameters due to the upcoming work on serialisation. Unless you see a strong motivation for users to tune this, let's leave it out for now.

Contributor Author:

I'm not quite sure whether this is useful for end users. Is there a way to make a parameter hidden/internal? It's really useful for the tests, since we don't have to build a dataset bigger than 32 MB.


void CheckDeprecated() {
if (this->n_gpus != 0) {
@@ -49,6 +51,10 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
.set_default(-1)
.set_lower_bound(-1)
.describe("The primary GPU device ordinal.");
DMLC_DECLARE_FIELD(gpu_page_size)
.set_default(0)
.set_lower_bound(0)
.describe("GPU page size when running in external memory mode.");
DMLC_DECLARE_FIELD(n_gpus)
.set_default(0)
.set_range(0, 1)
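Since the thread above leaves gpu_page_size user-visible for now, here is a sketch of setting it through the C API, mainly to exercise the external-memory path on small test data. The booster setup and the 1 MB value are illustrative, not taken from this PR:

#include <xgboost/c_api.h>

// dmats is assumed to be a previously created array of DMatrixHandle;
// external memory mode also requires constructing the DMatrix with a
// cache prefix.
BoosterHandle booster;
XGBoosterCreate(dmats, 1, &booster);
XGBoosterSetParam(booster, "tree_method", "gpu_hist");
// Force small ELLPACK pages so multiple batches are produced even when the
// dataset is far below the default 32 MB page size.
XGBoosterSetParam(booster, "gpu_page_size", "1048576");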
45 changes: 26 additions & 19 deletions src/data/ellpack_page.cu
@@ -69,6 +69,8 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) {
monitor_.Init("ellpack_page");
dh::safe_cuda(cudaSetDevice(param.gpu_id));

matrix.n_rows = dmat->Info().num_row_;

monitor_.StartCuda("Quantiles");
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
common::HistogramCuts hmat;
@@ -206,52 +208,57 @@ void EllpackPageImpl::CreateHistIndices(int device,

// Return the number of rows contained in this page.
size_t EllpackPageImpl::Size() const {
return n_rows;
return matrix.n_rows;
}

// Clear the current page.
void EllpackPageImpl::Clear() {
ba_.Clear();
gidx_buffer = {};
idx_buffer.clear();
n_rows = 0;
sparse_page_.Clear();
matrix.base_rowid = 0;
matrix.n_rows = 0;
}

// Push a CSR page to the current page.
//
// First compress the CSR page into ELLPACK, then the compressed buffer is copied to host and
// appended to the existing host vector.
// The CSR pages are accumulated in memory until they reach a certain size, then written out as
// compressed ELLPACK.
void EllpackPageImpl::Push(int device, const SparsePage& batch) {
sparse_page_.Push(batch);
matrix.n_rows += batch.Size();
}

// Compress the accumulated SparsePage.
void EllpackPageImpl::CompressSparsePage(int device) {
monitor_.StartCuda("InitCompressedData");
InitCompressedData(device, batch.Size());
InitCompressedData(device, matrix.n_rows);
monitor_.StopCuda("InitCompressedData");

monitor_.StartCuda("BinningCompression");
DeviceHistogramBuilderState hist_builder_row_state(batch.Size());
hist_builder_row_state.BeginBatch(batch);
CreateHistIndices(device, batch, hist_builder_row_state.GetRowStateOnDevice());
DeviceHistogramBuilderState hist_builder_row_state(matrix.n_rows);
hist_builder_row_state.BeginBatch(sparse_page_);
CreateHistIndices(device, sparse_page_, hist_builder_row_state.GetRowStateOnDevice());
hist_builder_row_state.EndBatch();
monitor_.StopCuda("BinningCompression");

monitor_.StartCuda("CopyDeviceToHost");
std::vector<common::CompressedByteT> buffer(gidx_buffer.size());
dh::CopyDeviceSpanToVector(&buffer, gidx_buffer);
int offset = 0;
if (!idx_buffer.empty()) {
offset = ::xgboost::common::detail::kPadding;
}
idx_buffer.reserve(idx_buffer.size() + buffer.size() - offset);
idx_buffer.insert(idx_buffer.end(), buffer.begin() + offset, buffer.end());
idx_buffer.resize(gidx_buffer.size());
dh::CopyDeviceSpanToVector(&idx_buffer, gidx_buffer);
ba_.Clear();
gidx_buffer = {};
monitor_.StopCuda("CopyDeviceToHost");

n_rows += batch.Size();
}

// Return the memory cost for storing the compressed features.
size_t EllpackPageImpl::MemCostBytes() const {
return idx_buffer.size() * sizeof(common::CompressedByteT);
size_t num_symbols = matrix.info.n_bins + 1;

// Required buffer size for storing data matrix in ELLPack format.
size_t compressed_size_bytes = common::CompressedBufferWriter::CalculateBufferSize(
matrix.info.row_stride * matrix.n_rows, num_symbols);
return compressed_size_bytes;
}

// Copy the compressed features to GPU.
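The rewritten MemCostBytes estimates the compressed size analytically from the matrix shape rather than measuring the host buffer. A rough model of that arithmetic (a sketch only; the real CompressedBufferWriter::CalculateBufferSize also rounds and adds padding):

#include <cmath>
#include <cstddef>

// Approximate compressed ELLPACK size: row_stride * n_rows symbols, each
// packed into ceil(log2(num_symbols)) bits, rounded up to whole bytes.
std::size_t ApproxEllpackBytes(std::size_t n_rows, std::size_t row_stride,
                               std::size_t num_symbols) {
  auto symbol_bits =
      static_cast<std::size_t>(std::ceil(std::log2(num_symbols)));
  std::size_t total_bits = n_rows * row_stride * symbol_bits;
  return (total_bits + 7) / 8;
}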
21 changes: 16 additions & 5 deletions src/data/ellpack_page.cuh
@@ -78,13 +78,14 @@ struct EllpackInfo {
* kernels.*/
struct EllpackMatrix {
EllpackInfo info;
size_t base_rowid{};
size_t n_rows{};
common::CompressedIterator<uint32_t> gidx_iter;

XGBOOST_DEVICE size_t BinCount() const { return info.gidx_fvalue_map.size(); }

// Get a matrix element, using binary search for lookup. Returns NaN if missing.
// Given a row index and a feature index, returns the corresponding cut value
__device__ bst_float GetElement(size_t ridx, size_t fidx) const {
ridx -= base_rowid;
auto row_begin = info.row_stride * ridx;
auto row_end = row_begin + info.row_stride;
auto gidx = -1;
@@ -102,6 +103,11 @@
}
return info.gidx_fvalue_map[gidx];
}

// Check if the row id is within the range of the current batch.
__device__ bool IsInRange(size_t row_id) const {
return row_id >= base_rowid && row_id < base_rowid + n_rows;
}
};
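With rows now split across pages, device code that is launched over all rows has to skip the rows a given page does not hold. A minimal sketch of the guard; the launch site and the names total_rows and fidx are assumed:

// Illustrative kernel body: d_matrix is a by-value copy of an EllpackMatrix,
// total_rows spans the whole DMatrix, fidx is some feature index.
auto d_matrix = page.Impl()->matrix;
dh::LaunchN(device, total_rows, [=] __device__(size_t ridx) {
  if (!d_matrix.IsInRange(ridx)) { return; }  // row lives in another page
  bst_float fvalue = d_matrix.GetElement(ridx, fidx);
  // ... consume fvalue ...
});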

// Instances of this type are created while creating the histogram bins for the
@@ -185,7 +191,6 @@ class EllpackPageImpl {
/*! \brief global index of histogram, which is stored in ELLPack format. */
common::Span<common::CompressedByteT> gidx_buffer;
std::vector<common::CompressedByteT> idx_buffer;
size_t n_rows{};

/*!
* \brief Default constructor.
@@ -240,7 +245,7 @@ class EllpackPageImpl {

/*! \brief Set the base row id for this page. */
inline void SetBaseRowId(size_t row_id) {
base_rowid_ = row_id;
matrix.base_rowid = row_id;
}

/*! \brief clear the page. */
@@ -263,11 +268,17 @@
*/
void InitDevice(int device, EllpackInfo info);

/*! \brief Compress the accumulated SparsePage into ELLPACK format.
*
* @param device The GPU device to use.
*/
void CompressSparsePage(int device);

private:
common::Monitor monitor_;
dh::BulkAllocator ba_;
size_t base_rowid_{};
bool device_initialized_{false};
SparsePage sparse_page_{};
};

} // namespace xgboost
6 changes: 3 additions & 3 deletions src/data/ellpack_page_raw_format.cu
@@ -17,21 +17,21 @@ class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
public:
bool Read(EllpackPage* page, dmlc::SeekStream* fi) override {
auto* impl = page->Impl();
if (!fi->Read(&impl->n_rows)) return false;
if (!fi->Read(&impl->matrix.n_rows)) return false;
return fi->Read(&impl->idx_buffer);
}

bool Read(EllpackPage* page,
dmlc::SeekStream* fi,
const std::vector<bst_uint>& sorted_index_set) override {
auto* impl = page->Impl();
if (!fi->Read(&impl->n_rows)) return false;
if (!fi->Read(&impl->matrix.n_rows)) return false;
return fi->Read(&page->Impl()->idx_buffer);
}

void Write(const EllpackPage& page, dmlc::Stream* fo) override {
auto* impl = page.Impl();
fo->Write(impl->n_rows);
fo->Write(impl->matrix.n_rows);
auto buffer = impl->idx_buffer;
CHECK(!buffer.empty());
fo->Write(buffer);
22 changes: 16 additions & 6 deletions src/data/ellpack_page_source.cu
@@ -40,11 +40,13 @@ class EllpackPageSourceImpl : public DataSource<EllpackPage> {
const std::string kPageType_{".ellpack.page"};

int device_{-1};
size_t page_size_{DMatrix::kPageSize};
common::Monitor monitor_;
dh::BulkAllocator ba_;
/*! \brief The EllpackInfo, with the underlying GPU memory shared by all pages. */
EllpackInfo ellpack_info_;
std::unique_ptr<SparsePageSource<EllpackPage>> source_;
std::string cache_info_;
};

EllpackPageSource::EllpackPageSource(DMatrix* dmat,
@@ -72,8 +74,12 @@ const EllpackPage& EllpackPageSource::Value() const {
// each CSR page, and write the accumulated ELLPACK pages to disk.
EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat,
const std::string& cache_info,
const BatchParam& param) noexcept(false) {
device_ = param.gpu_id;
const BatchParam& param) noexcept(false)
: device_(param.gpu_id), cache_info_(cache_info) {

if (param.gpu_page_size > 0) {
page_size_ = param.gpu_page_size;
}

monitor_.Init("ellpack_page_source");
dh::safe_cuda(cudaSetDevice(device_));
@@ -92,10 +98,11 @@ EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat,
WriteEllpackPages(dmat, cache_info);
monitor_.StopCuda("WriteEllpackPages");

source_.reset(new SparsePageSource<EllpackPage>(cache_info, kPageType_));
source_.reset(new SparsePageSource<EllpackPage>(cache_info_, kPageType_));
}

void EllpackPageSourceImpl::BeforeFirst() {
source_.reset(new SparsePageSource<EllpackPage>(cache_info_, kPageType_));
source_->BeforeFirst();
}
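Recreating the source on rewind, rather than only calling BeforeFirst() on the existing reader, appears to be what allows the cached ELLPACK pages to be iterated more than once (compare the "test looping through ellpack pages multiple times" commit); the stored cache_info_ is what makes the rebuild possible.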

@@ -133,20 +140,23 @@ void EllpackPageSourceImpl::WriteEllpackPages(DMatrix* dmat, const std::string&
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
impl->Push(device_, batch);

if (impl->MemCostBytes() >= DMatrix::kPageSize) {
bytes_write += impl->MemCostBytes();
size_t mem_cost_bytes = impl->MemCostBytes();
if (mem_cost_bytes >= page_size_) {
bytes_write += mem_cost_bytes;
impl->CompressSparsePage(device_);
writer.PushWrite(std::move(page));
writer.Alloc(&page);
impl = page->Impl();
impl->matrix.info = ellpack_info_;
impl->Clear();
double tdiff = dmlc::GetTime() - tstart;
LOG(INFO) << "Writing to " << cache_info << " in "
LOG(INFO) << "Writing " << kPageType_ << " to " << cache_info << " in "
<< ((bytes_write >> 20UL) / tdiff) << " MB/s, "
<< (bytes_write >> 20UL) << " written";
}
}
if (impl->Size() != 0) {
impl->CompressSparsePage(device_);
writer.PushWrite(std::move(page));
}
}
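Note that any rows still accumulated after the loop are compressed and written as a final, possibly smaller, page. With the default DMatrix::kPageSize threshold a dataset has to exceed roughly 32 MB before a second page is produced, which is why the review discussion above favours lowering gpu_page_size in tests.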
5 changes: 1 addition & 4 deletions src/data/sparse_page_dmatrix.cc
@@ -81,10 +81,7 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& par
CHECK_GE(param.gpu_id, 0);
CHECK_GE(param.max_bin, 2);
// Lazily instantiate
if (!ellpack_source_ ||
batch_param_.gpu_id != param.gpu_id ||
batch_param_.max_bin != param.max_bin ||
batch_param_.gpu_batch_nrows != param.gpu_batch_nrows) {
if (!ellpack_source_ || batch_param_ != param) {
ellpack_source_.reset(new EllpackPageSource(this, cache_info_, param));
batch_param_ = param;
}
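This simplification relies on the BatchParam::operator!= added in include/xgboost/data.h above; since that operator also compares gpu_page_size, a changed page size now invalidates and rebuilds the cached ELLPACK source as well.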
6 changes: 5 additions & 1 deletion src/tree/gpu_hist/row_partitioner.cuh
@@ -33,6 +33,7 @@ class RowPartitioner {
using TreePositionT = int32_t;
using RowIndexT = bst_uint;
struct Segment;
static constexpr TreePositionT kIgnoredTreePosition = -1;

private:
int device_idx;
@@ -124,6 +125,7 @@ class RowPartitioner {
idx += segment.begin;
RowIndexT ridx = d_ridx[idx];
TreePositionT new_position = op(ridx); // new node id
if (new_position == kIgnoredTreePosition) return;
KERNEL_CHECK(new_position == left_nidx || new_position == right_nidx);
AtomicIncrement(d_left_count, new_position == left_nidx);
d_position[idx] = new_position;
@@ -163,7 +165,9 @@ class RowPartitioner {
dh::LaunchN(device_idx, position.Size(), [=] __device__(size_t idx) {
auto position = d_position[idx];
RowIndexT ridx = d_ridx[idx];
d_position[idx] = op(ridx, position);
TreePositionT new_position = op(ridx, position);
if (new_position == kIgnoredTreePosition) return;
d_position[idx] = new_position;
});
}

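Finally, a sketch of how an update operator can use the new sentinel so that rows outside the current ELLPACK page keep their existing position. The split variables and node ids are assumed (left_nidx and right_nidx taken as TreePositionT), and real code would also handle the default direction for missing values:

// Illustrative op passed to UpdatePosition: rows not covered by the current
// page return kIgnoredTreePosition and are left untouched by the kernel.
row_partitioner->UpdatePosition(
    nidx, left_nidx, right_nidx,
    [=] __device__(RowPartitioner::RowIndexT ridx) {
      if (!matrix.IsInRange(ridx)) {
        return RowPartitioner::kIgnoredTreePosition;
      }
      bst_float fvalue = matrix.GetElement(ridx, split_fidx);
      return fvalue <= split_value ? left_nidx : right_nidx;
    });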