[RFC] External memory support for GPU #4357
Comments
Adding a task list here to keep track of the progress:
@rongou Sounds awesome. Could you add some details about the preprocessing step? I have been looking into the sketching code recently.
What I'm thinking is more or less a pure refactoring. Right now we hand off the sparse page DMatrix to the tree updater, which loops through it once to build the quantiles, then loops through it again to compress each batch. We probably need to do these at a higher level so that we can write out the compressed features. Hand-wavy pseudo-code:
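(The snippet below is only a minimal illustrative sketch; the batch-iteration API and the helpers such as HistogramCuts, Compress, and WriteToCache are placeholders, not the actual interface.)

```cpp
// Illustrative sketch only: types and helper functions are placeholders.
void PreprocessForExternalMemory(DMatrix* dmat) {
  HistogramCuts cuts;
  // Pass 1: stream over every sparse page to build the quantile sketch.
  for (const SparsePage& page : dmat->GetBatches<SparsePage>()) {
    cuts.AddBatch(page);
  }
  cuts.Finalize();

  // Pass 2: compress each page against the cuts and write it back out,
  // so the tree updater only ever sees compressed pages.
  for (const SparsePage& page : dmat->GetBatches<SparsePage>()) {
    CompressedPage compressed = Compress(page, cuts);
    WriteToCache(compressed);
  }
}
```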
Will try to get a draft PR out next week.
@rongou Preferably this is done in DMatrix; you can add a new method to DMatrix called get compressed hist index ...
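A possible shape for such a method, purely as a sketch (the name, signature, and the CompressedHistIndex type are all hypothetical):

```cpp
// Hypothetical sketch only; not the actual DMatrix interface.
class DMatrix {
 public:
  // ... existing methods ...

  // Build (or fetch from cache) the quantized histogram index for this data,
  // so updaters never have to touch the raw feature values.
  virtual CompressedHistIndex* GetCompressedHistIndex(int max_bins) = 0;
};
```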
@trivialfis Isn't DMatrix just a data representation type, and isn't it agnostic to tree methods such as histograms? If so, is it right to pollute this interface with such methods?
@sriramch We think building histogram indices inside DMatrix can give us some opportunities for saving memory; for example, it may be possible to avoid copying the original data set. To me it makes sense to consider histogram indices as data, since the histogram index is used to replace the actual input data. Also, this is inspired by LGB. I think @hcho3 and @RAMitchell can provide some more input on other issues with the current index-building method. We can talk about it in detail in the DMatrix RFC, since I'm not sure about how to integrate it with external memory yet.
I guess it depends on your definition of "support". :) It now supports the # cache file syntax, but the whole dataset is still kept in GPU memory.
Indeed. Let's keep the issue open then. Thanks.
Would the Apache Arrow CUDA / off-heap memory support (https://arrow.apache.org/docs/python/cuda.html) help here? Has it been considered?
By off-heap do you mean reading from disk? Maybe in the future, but currently not a priority.
Did you mean running XGBoost on a dataset with the Arrow memory layout without making a copy? If so, then not in the near future. Internally XGBoost uses CSR (for most parts), and they are quite different. Having said that, it's still possible, I think.
I'm interested in this feature. But again, so many to-dos ...
I have a PR for initial support of cuDF. I think the Arrow specification still has some room to grow; let's see what can be done after merging that PR. Hope that helps.
@SemanticBeeng If you have some ideas, please do share.
Thanks.
Am looking to determine if the goal of "standardizing
Ah, yes, this is important to know. Do you know if
I use XGBoost 0.81 with the CPU. External memory is supposed to allow training on datasets that don't fit in main (a.k.a. host) memory, but when I train, memory usage is the same (17 GB) whether I use external memory or not.
Documentation should be the last item now.
External memory support for GPU and gradient-based sampling is available! Thanks to @rongou
Motivation
XGBoost has experimental support for external memory that allows for training on datasets that don’t fit in main (a.k.a. host) memory. However, it’s not available with GPU algorithms. To train on large datasets with GPUs, we have to either downsample the data, which defeats the purpose, or scale out to multi-GPU and multi-node settings, which have their own complexities and limitations. It’s desirable to make the GPU algorithms more flexible by adding external memory support.
Goals
Add external memory support to the gpu_hist tree method.

Non-Goals
This proposal is for XGBoost only, and doesn’t apply to other libraries in RAPIDS such as cuDF/cuML.
Assumptions
The input data is in libsvm format.

Risks
Design
Existing Code
The current XGBoost code allows the user to specify the input data as filename#cacheprefix, where filename is the normal path to a libsvm file, and cacheprefix is a path to a cache file that XGBoost will use for the external memory cache. During training, the data is read in, parsed, and written out to the cache in 32 MB pages. For each round, the pages are read back in and fed to the learner as batches.
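For reference, a minimal sketch of passing this cache syntax through the C API (the file names here are just examples):

```cpp
#include <xgboost/c_api.h>

int main() {
  DMatrixHandle dtrain;
  // The part before '#' is the libsvm file; the part after is the prefix
  // XGBoost uses for its external memory cache files.
  XGDMatrixCreateFromFile("train.libsvm#dtrain.cache", 0, &dtrain);
  // ... create a booster, train, etc. ...
  XGDMatrixFree(dtrain);
  return 0;
}
```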
Currently the gpu_hist tree method only accepts the first batch and errors out if there is more than one batch. In the training loop, the data is initialized once by going through feature quantile generation and data compression. The compressed feature quantiles are kept in GPU memory during training.

The GPU algorithm is described in detail in the paper.
Adding external memory support can be split into the following phases.
Phase 1: Basic Correctness
Here we aim to support external memory in the simplest way: instead of reading only the first SparsePage within a DMatrix, the algorithm should iterate over all SparsePages. One existing limitation is that only training/prediction data is read from external memory, while predictions are kept fully in memory.
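A rough sketch of the intended batch loop, assuming the existing SparsePage batch iteration; BuildHistograms and EvaluateSplits are stand-ins for the real gpu_hist steps, not actual function names:

```cpp
// Sketch only: accumulate histograms incrementally from every SparsePage
// instead of requiring the whole dataset to arrive in a single batch.
void UpdateOneIteration(DMatrix* dmat, RegTree* tree) {
  for (const SparsePage& batch : dmat->GetBatches<SparsePage>()) {
    BuildHistograms(batch, tree);
  }
  EvaluateSplits(tree);
}
```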
Phase 2: Cache Compressed Features

Based on current timing information, feature quantile generation and data compression are relatively expensive (equivalent to building around 40 trees), and will add up quickly if we perform them on the fly for every batch. We can write the compressed features back to disk, similar to how data caching is currently done for external memory. Compressed and binned features take less space (~12 bits after compression and binning vs. 64 bits before, roughly a 5x reduction), and are therefore cheaper to read. This is tied to the DMatrix refactoring (#4354).

In the multi-GPU setup, right now each data batch is split between different GPUs. When we properly support multiple batches, we can instead split at the batch level and feed whole batches to each GPU.
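A sketch of what batch-level splitting might look like (simple round-robin assignment; EnqueueOnDevice is a hypothetical helper that copies a batch to the given GPU and processes it there):

```cpp
// Sketch only: whole batches go to one GPU each, instead of
// splitting every batch by rows across all GPUs.
void DistributeBatches(DMatrix* dmat, int n_gpus) {
  int batch_idx = 0;
  for (const SparsePage& batch : dmat->GetBatches<SparsePage>()) {
    EnqueueOnDevice(batch_idx % n_gpus, batch);
    ++batch_idx;
  }
}
```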
Phase 3: Overlap Data Transfers and Tree Construction
Regardless of the type of storage used to cache the compressed features, it’s likely to be very slow to feed each batch to the GPU and then build the histograms serially. We need to put the data transfer and compute for each batch into a separate CUDA stream, and launch multiple streams asynchronously.
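A sketch of the double-buffering pattern with CUDA streams; the HostBatch type, batch sizes, and ProcessBatchKernel are hypothetical, and the host buffers would need to be pinned for the copies to actually overlap with compute:

```cpp
#include <cuda_runtime.h>
#include <cstddef>
#include <vector>

// Hypothetical host-side view of one compressed batch.
struct HostBatch { const void* data; size_t size; };

// Stand-in for the real histogram-building kernel.
__global__ void ProcessBatchKernel(const char* batch, size_t size) { /* ... */ }

void BuildHistogramsOverlapped(const std::vector<HostBatch>& h_batches,
                               size_t max_batch_bytes) {
  constexpr int kNumStreams = 2;  // double buffering
  cudaStream_t streams[kNumStreams];
  char* d_buffers[kNumStreams];
  for (int i = 0; i < kNumStreams; ++i) {
    cudaStreamCreate(&streams[i]);
    cudaMalloc(reinterpret_cast<void**>(&d_buffers[i]), max_batch_bytes);
  }

  for (size_t i = 0; i < h_batches.size(); ++i) {
    int slot = static_cast<int>(i % kNumStreams);
    // The copy issued into this stream can overlap with the kernel still
    // running in the other stream; within a stream, the copy for batch i+2
    // waits for batch i's kernel, so the buffer is not overwritten too early.
    cudaMemcpyAsync(d_buffers[slot], h_batches[i].data, h_batches[i].size,
                    cudaMemcpyHostToDevice, streams[slot]);
    ProcessBatchKernel<<<256, 256, 0, streams[slot]>>>(d_buffers[slot],
                                                       h_batches[i].size);
  }
  cudaDeviceSynchronize();

  for (int i = 0; i < kNumStreams; ++i) {
    cudaStreamDestroy(streams[i]);
    cudaFree(d_buffers[i]);
  }
}
```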
Phase 4: (Optionally) Use CUDA Graphs
We can potentially define the data transfers and tree construction in a CUDA graph and launch the graph repeatedly for each batch. This is only available for CUDA 10 and above.
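A sketch of capturing the per-batch copy and kernel launch into a CUDA graph and replaying it for each batch; this reuses the hypothetical HostBatch and ProcessBatchKernel from the previous sketch and requires CUDA 10 or later:

```cpp
#include <cuda_runtime.h>
#include <cstring>
#include <vector>

// Assumes HostBatch and ProcessBatchKernel from the Phase 3 sketch;
// h_staging must be pinned host memory of max_batch_bytes, and d_buffer
// a device buffer of the same size.
void BuildHistogramsWithGraph(const std::vector<HostBatch>& h_batches,
                              char* h_staging, char* d_buffer,
                              size_t max_batch_bytes) {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Capture the per-batch copy + kernel launch once.
  cudaGraph_t graph;
  cudaGraphExec_t graph_exec;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  cudaMemcpyAsync(d_buffer, h_staging, max_batch_bytes,
                  cudaMemcpyHostToDevice, stream);
  ProcessBatchKernel<<<256, 256, 0, stream>>>(d_buffer, max_batch_bytes);
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0);

  // Replay the captured graph for every batch instead of re-issuing
  // the individual copy and launch calls each time.
  for (const HostBatch& batch : h_batches) {
    std::memcpy(h_staging, batch.data, batch.size);
    cudaGraphLaunch(graph_exec, stream);
    cudaStreamSynchronize(stream);
  }

  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
}
```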
Alternatives Considered
When training on a very large dataset (e.g. 10-20 TB), one approach is to scale out to many nodes. However, it may require hundreds or even thousands of GPUs, which may be prohibitively expensive and hard to schedule in a shared cluster. With that many nodes, network communication may also become a bottleneck, especially in a cloud/enterprise environment relying on relatively slow Ethernet connections. The goal of external memory support is not to replace distributed training, but to provide an alternative that may be better suited to some situations.
Another approach is stochastic gradient boosting: at each iteration a subsample of the training data is drawn at random (without replacement) from the full training data set. This randomly selected subsample is then used in place of the full sample to fit the base learner and compute the model update for the current iteration (see paper). However, the sampling ratio is a hyperparameter that needs to be tuned, so the subsample may still not fit in memory. There are other sampling techniques (for example, see paper), which are algorithmic improvements largely orthogonal to external memory support.
@RAMitchell @canonizer @sriramch