
Improve FIL code readability and documentation #3056

Merged — 19 commits, Jun 30, 2021
Changes from 12 commits
93774eb
hide inheritance structure and clean up comments
levsnv Oct 23, 2020
4423872
Merge remote-tracking branch 'rapidsai/branch-0.17' into neater-code
levsnv Nov 12, 2020
73a43db
clearer __syncthreads() purpose
levsnv Nov 12, 2020
da8529f
simplify treelite.Model handle extraction
levsnv Nov 14, 2020
1e00e64
removed deprecated "silent" from xgboost params; repeated code
levsnv Nov 14, 2020
9e7463c
moved all_set closer to where it gets affected
levsnv Nov 14, 2020
aa002fc
fixed old bug
levsnv Nov 14, 2020
1db94e5
Revert "simplify treelite.Model handle extraction"
levsnv Nov 21, 2020
6754cbf
Merge branch 'branch-21.08' into neater-code
levsnv Jun 4, 2021
008fd55
update documentation
levsnv Jun 5, 2021
48d868b
delayed merge error fix
levsnv Jun 5, 2021
7dca8a3
python newlines
levsnv Jun 5, 2021
b8a1d1f
Merge branch 'branch-21.08' of github.com:rapidsai/cuml into neater-code
levsnv Jun 19, 2021
f4fb48e
fixed all but static method docstring modifications
levsnv Jun 22, 2021
7e9fb87
addressed review comments; worked around static method patching
levsnv Jun 23, 2021
192aafb
Merge branch 'branch-21.08' of github.com:rapidsai/cuml into neater-code
levsnv Jun 24, 2021
1d16b11
Merge branch 'branch-21.08' of github.com:rapidsai/cuml into neater-code
levsnv Jun 29, 2021
db90edb
X docstring is now generated almost like auto-generated docstrings
levsnv Jun 30, 2021
b1c8e8c
Merge branch 'branch-21.08' of github.com:rapidsai/cuml into neater-code
levsnv Jun 30, 2021
1 change: 1 addition & 0 deletions cpp/src/fil/fil.cu
@@ -429,6 +429,7 @@ void check_params(const forest_params_t* params, bool dense) {
"leaf_algo must be FLOAT_UNARY_BINARY, CATEGORICAL_LEAF"
" or GROVE_PER_CLASS");
}
// output_t::RAW == 0, and doesn't have a separate flag
if ((params->output & ~output_t::ALL_SET) != 0) {
ASSERT(
false,
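The `(params->output & ~output_t::ALL_SET) != 0` check above rejects any flag bit outside the known set; as the new comment notes, `output_t::RAW` is zero and therefore needs no bit of its own. A minimal sketch of that bitmask-validation pattern (hypothetical flag names, not FIL's actual enum values):

```python
# Hypothetical mirror of FIL's output_t flags: RAW is 0, so it has no bit.
RAW, AVG, SIGMOID, CLASS = 0, 1, 2, 4
ALL_SET = AVG | SIGMOID | CLASS  # union of every real flag bit

def check_output(flags):
    # Reject any bit outside the known set (mirrors the C++ ASSERT):
    # masking with ~ALL_SET leaves only unknown bits.
    return (flags & ~ALL_SET) == 0

assert check_output(RAW)            # 0 has no stray bits
assert check_output(AVG | SIGMOID)  # known combination passes
assert not check_output(8)          # unknown flag bit -> invalid
```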
6 changes: 4 additions & 2 deletions cpp/src/fil/infer.cu
@@ -164,7 +164,9 @@ __device__ __forceinline__ vec<1, output_type> infer_one_tree(
curr = n.left(curr) + cond;
}
vec<1, output_type> out;
out[0] = tree[curr].base_node::output<output_type>();
/** Dependent names are not considered templates by default,
    unless they are members of the current [template] instantiation. **/
out[0] = tree[curr].template output<output_type>();
return out;
}

@@ -499,7 +501,7 @@ struct tree_aggregator_t<NITEMS, CATEGORICAL_LEAF> {
// or class probabilities or regression
__device__ __forceinline__ void finalize_class_label(float* out,
int num_rows) {
__syncthreads();
__syncthreads(); // make sure all votes[] are final
int item = threadIdx.x;
int row = item;
if (item < NITEMS && row < num_rows) {
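The clarified comment on `__syncthreads()` says it guarantees every `votes[]` entry is final before any thread reads them. A hedged host-side analogy of that barrier pattern, using Python's `threading.Barrier` (an illustration of the synchronization idea only, not CUDA semantics):

```python
import threading

NITEMS = 4
votes = [0] * NITEMS            # shared array, one slot per "thread"
barrier = threading.Barrier(NITEMS)
results = [None] * NITEMS

def worker(tid):
    votes[tid] = tid + 1        # each thread writes its own vote
    barrier.wait()              # like __syncthreads(): all votes[] are final
    results[tid] = sum(votes)   # now safe to read every other thread's vote

threads = [threading.Thread(target=worker, args=(i,)) for i in range(NITEMS)]
for t in threads:
    t.start()
for t in threads:
    t.join()

assert results == [10, 10, 10, 10]  # every thread saw all four votes
```

Without the barrier, a fast thread could sum `votes` while slower threads were still writing, which is exactly the race the CUDA comment guards against.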
8 changes: 5 additions & 3 deletions python/cuml/fil/README.md
@@ -33,17 +33,18 @@ Additionally, FIL can be called directly from C or C++ code. See [the API docs h
# Features

* Input model source: XGBoost (binary format), cuML RandomForest, scikit-learn RandomForest, LightGBM
* Model types: Regression, Binary Classification, Multi-class Classification (for cuML Random Forests, but not GBDTs or scikit-learn Random Forests)
* Model types: Regression, Binary Classification, Multi-class Classification (for cuML Random Forests or GBDTs, but not scikit-learn Random Forests)
* Tree storage types: Dense or sparse tree storage (see Sparse Forests with FIL blog below)
* Input formats: Dense, row-major, FP32 arrays on GPU or CPU (e.g. NumPy, cuPy, or other data formats supported by cuML). Trees are expected to be trained for float32 inputs. There may be rounding differences if trees were trained for float64 inputs.
* High performance batch inference
* Input parsing based on [Treelite](https://github.com/dmlc/treelite)

Upcoming features:

* Support for multi-class GBDTs is planned for RAPIDS 0.16
* Support for multi-class random forests from scikit-learn
* Support for smaller node storage (8-byte) to reduce memory usage for
small trees is experimental
* Categorical features for LightGBM models

# Benchmarks and performance notes

@@ -74,5 +75,6 @@ GPU, using FIL 0.9.)
* [RAPIDS Forest Inference Library: Prediction at 100 million rows per second](https://medium.com/rapids-ai/rapids-forest-inference-library-prediction-at-100-million-rows-per-second-19558890bc35)
* [Sparse Forests with FIL](https://medium.com/rapids-ai/sparse-forests-with-fil-ffbb42b0c7e3)
* [GBM Inferencing on GPU (earlier research work)](https://on-demand.gputechconf.com/gtc/2018/presentation/s8873-gbm-inferencing-on-gpu-v2.pdf)
* [GBM Inferencing on GPU, 2018 talk (earlier research work)](https://on-demand.gputechconf.com/gtc/2018/presentation/s8873-gbm-inferencing-on-gpu-v2.pdf)
* [Sample Notebook](https://github.com/rapidsai/cuml/blob/branch-0.16/notebooks/forest_inference_demo.ipynb)
* [GTC 2021 talk](https://www.nvidia.com/en-us/on-demand/session/gtcspring21-s31296/)
214 changes: 60 additions & 154 deletions python/cuml/fil/fil.pyx
@@ -401,7 +401,7 @@ class ForestInference(Base,

**Known limitations**:
* A single row of data should fit into the shared memory of a thread
block, which means that more than 12288 features are not supported.
block, which means that rows with more than 5000-12288 features
(depending on the GPU) will be inferred from L1
* From sklearn.ensemble, only
{RandomForest,GradientBoosting,ExtraTrees}{Classifier,Regressor} models
are supported. Other sklearn.ensemble models are currently not
@@ -466,6 +466,59 @@ class ForestInference(Base,

"""

common_load_params_docstring = """
output_class: boolean (default=False)
For a Classification model `output_class` must be True.
For a Regression model `output_class` must be False.
algo : string (default='auto')
Name of the algo (from the algo_t enum):

- ``'AUTO'`` or ``'auto'``: Choose the algorithm automatically.
Currently 'BATCH_TREE_REORG' is used for dense storage,
and 'NAIVE' for sparse storage
- ``'NAIVE'`` or ``'naive'``: Simple inference using shared memory
- ``'TREE_REORG'`` or ``'tree_reorg'``: Similar to naive but trees
rearranged to be more coalescing-friendly
- ``'BATCH_TREE_REORG'`` or ``'batch_tree_reorg'``: Similar to
TREE_REORG but predicting multiple rows per thread block

threshold : float (default=0.5)
Threshold is used for classification. It is applied
only if ``output_class == True``, else it is ignored.
storage_type : string or boolean (default='auto')
In-memory storage format to be used for the FIL model:

- ``'auto'``: Choose the storage type automatically
(currently DENSE is always used)
- ``False``: Create a dense forest
- ``True``: Create a sparse forest. Requires algo='NAIVE' or
algo='AUTO'

blocks_per_sm : integer (default=0)
(experimental) Indicates how the number of thread blocks to launch
for the inference kernel is determined.

- ``0`` (default): Launches the number of blocks proportional to
the number of data rows
- ``>= 1``: Attempts to launch blocks_per_sm blocks per SM. This
will fail if blocks_per_sm blocks result in more threads than the
maximum supported number of threads per GPU. Even if successful,
it is not guaranteed that blocks_per_sm blocks will run on an SM
concurrently.
compute_shape_str : boolean (default=False)
If True or equivalent, creates a ForestInference.shape_str attribute
(a human-readable forest shape description as a multiline ASCII
string)
"""

common_predict_params_docstring = """
X : array-like (device or host) shape = (n_samples, n_features)
Dense matrix (floats) of shape (n_samples, n_features).
Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device
ndarray, or a CUDA array interface compliant array such as CuPy.
For optimal performance, pass a device array with C-style layout.
"""

def __init__(self, *,
handle=None,
output_type=None,
@@ -486,11 +539,7 @@ class ForestInference(Base,

Parameters
----------
X : array-like (device or host) shape = (n_samples, n_features)
Dense matrix (floats) of shape (n_samples, n_features).
Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device
ndarray, cuda array interface compliant array like CuPy
For optimal performance, pass a device array with C-style layout
""" + ForestInference.common_predict_params_docstring + """
preds: gpuarray or cudf.Series, shape = (n_samples,)
Optional 'out' location to store inference results

@@ -509,11 +558,7 @@ class ForestInference(Base,

Parameters
----------
X : array-like (device or host) shape = (n_samples, n_features)
Dense matrix (floats) of shape (n_samples, n_features).
Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device
ndarray, cuda array interface compliant array like CuPy
For optimal performance, pass a device array with C-style layout
""" + ForestInference.common_predict_params_docstring + """
preds: gpuarray or cudf.Series, shape = (n_samples,2)
Binary probability output
Optional 'out' location to store inference results
@@ -541,53 +586,7 @@ class ForestInference(Base,
the trained model information in the treelite format
loaded from a saved model using the treelite API
https://treelite.readthedocs.io/en/latest/treelite-api.html
output_class: boolean (default=False)
For a Classification model `output_class` must be True.
For a Regression model `output_class` must be False.
algo : string (default='auto')
Name of the algo from (from algo_t enum):

- ``'AUTO'`` or ``'auto'``: choose the algorithm automatically.
Currently 'BATCH_TREE_REORG' is used for dense storage,
and 'NAIVE' for sparse storage
- ``'NAIVE'`` or ``'naive'``: simple inference using shared memory
- ``'TREE_REORG'`` or ``'tree_reorg'``: similar to naive but trees
rearranged to be more coalescing-friendly
- ``'BATCH_TREE_REORG'`` or ``'batch_tree_reorg'``: similar to
TREE_REORG but predicting multiple rows per thread block

threshold : float (default=0.5)
Threshold is used to for classification. It is applied
only if ``output_class == True``, else it is ignored.
storage_type : string or boolean (default='auto')
In-memory storage format to be used for the FIL model:

- ``'auto'``: Choose the storage type automatically
(currently DENSE is always used)
- ``False``: Create a dense forest
- ``True``: Create a sparse forest. Requires algo='NAIVE' or
algo='AUTO'
- ``'sparse8'``: (experimental) Create a sparse forest with 8-byte
nodes. Requires algo='NAIVE' or algo='AUTO'. Can fail if 8-byte
nodes are not enough to store the forest, e.g. if there are too
many nodes in a tree or too many features

blocks_per_sm : integer (default=0)
(experimental) Indicates how the number of thread blocks to lauch
for the inference kernel is determined.

- ``0`` (default): Launches the number of blocks proportional to
the number of data rows
- ``>= 1``: Attempts to lauch blocks_per_sm blocks per SM. This
will fail if blocks_per_sm blocks result in more threads than the
maximum supported number of threads per GPU. Even if successful,
it is not guaranteed that blocks_per_sm blocks will run on an SM
concurrently.
compute_shape_str : boolean (default=False)
if True or equivalent, creates a ForestInference.shape_str
(writes a human-readable forest shape description as a
multiline ascii string)

""" + ForestInference.common_load_params_docstring + """
Returns
----------
fil_model
@@ -622,48 +621,7 @@ class ForestInference(Base,
----------
skl_model
The scikit-learn model from which to build the FIL version.
output_class: boolean (default=False)
For a Classification model `output_class` must be True.
For a Regression model `output_class` must be False.
algo : string (default='auto')
Name of the algo from (from algo_t enum):

- ``'AUTO'`` or ``'auto'``: Choose the algorithm automatically.
Currently 'BATCH_TREE_REORG' is used for dense storage,
and 'NAIVE' for sparse storage
- ``'NAIVE'`` or ``'naive'``: Simple inference using shared memory
- ``'TREE_REORG'`` or ``'tree_reorg'``: Similar to naive but trees
rearranged to be more coalescing-friendly
- ``'BATCH_TREE_REORG'`` or ``'batch_tree_reorg'``: Similar to
TREE_REORG but predicting multiple rows per thread block

threshold : float (default=0.5)
Threshold is used to for classification. It is applied
only if ``output_class == True``, else it is ignored.
storage_type : string or boolean (default='auto')
In-memory storage format to be used for the FIL model:

- ``'auto'``: Choose the storage type automatically
(currently DENSE is always used)
- ``False``: Create a dense forest
- ``True``: Create a sparse forest. Requires algo='NAIVE' or
algo='AUTO'

blocks_per_sm : integer (default=0)
(experimental) Indicates how the number of thread blocks to lauch
for the inference kernel is determined.

- ``0`` (default): Launches the number of blocks proportional to
the number of data rows
- ``>= 1``: Attempts to lauch blocks_per_sm blocks per SM. This
will fail if blocks_per_sm blocks result in more threads than the
maximum supported number of threads per GPU. Even if successful,
it is not guaranteed that blocks_per_sm blocks will run on an SM
concurrently.
compute_shape_str : boolean (default=False)
if True or equivalent, creates a ForestInference.shape_str
(writes a human-readable forest shape description as a
multiline ascii string)
""" + ForestInference.common_load_params_docstring + """

Returns
----------
@@ -702,33 +660,7 @@ class ForestInference(Base,
Path to saved model file in a treelite-compatible format
(See https://treelite.readthedocs.io/en/latest/treelite-api.html
for more information)
output_class: boolean (default=False)
For a Classification model `output_class` must be True.
For a Regression model `output_class` must be False.
threshold : float (default=0.5)
Cutoff value above which a prediction is set to 1.0
Only used if the model is classification and `output_class` is True
algo : string (default='auto')
Which inference algorithm to use.
See documentation in `FIL.load_from_treelite_model`
storage_type : string (default='auto')
In-memory storage format to be used for the FIL model.
See documentation in `FIL.load_from_treelite_model`
blocks_per_sm : integer (default=0)
(experimental) Indicates how the number of thread blocks to lauch
for the inference kernel is determined.

- ``0`` (default): Launches the number of blocks proportional to
the number of data rows
- ``>= 1``: Attempts to lauch blocks_per_sm blocks per SM. This
will fail if blocks_per_sm blocks result in more threads than the
maximum supported number of threads per GPU. Even if successful,
it is not guaranteed that blocks_per_sm blocks will run on an SM
concurrently.
compute_shape_str : boolean (default=False)
if True or equivalent, creates a ForestInference.shape_str
(writes a human-readable forest shape description as a
multiline ascii string)
""" + ForestInference.common_load_params_docstring + """
model_type : string (default="xgboost")
Format of the saved treelite model to be load.
It can be 'xgboost', 'xgboost_json', 'lightgbm'.
@@ -765,33 +697,7 @@ class ForestInference(Base,
model_handle : Modelhandle to the treelite forest model
(See https://treelite.readthedocs.io/en/latest/treelite-api.html
for more information)
output_class: boolean (default=False)
For a Classification model `output_class` must be True.
For a Regression model `output_class` must be False.
threshold : float (default=0.5)
Cutoff value above which a prediction is set to 1.0
Only used if the model is classification and `output_class` is True
algo : string (default='auto')
Which inference algorithm to use.
See documentation in `FIL.load_from_treelite_model`
storage_type : string (default='auto')
In-memory storage format to be used for the FIL model.
See documentation in `FIL.load_from_treelite_model`
blocks_per_sm : integer (default=0)
(experimental) Indicates how the number of thread blocks to lauch
for the inference kernel is determined.

- ``0`` (default): Launches the number of blocks proportional to
the number of data rows
- ``>= 1``: Attempts to lauch blocks_per_sm blocks per SM. This
will fail if blocks_per_sm blocks result in more threads than the
maximum supported number of threads per GPU. Even if successful,
it is not guaranteed that blocks_per_sm blocks will run on an SM
concurrently.
compute_shape_str : boolean (default=False)
if True or equivalent, creates a ForestInference.shape_str
(writes a human-readable forest shape description as a
multiline ascii string)
""" + ForestInference.common_load_params_docstring + """

Returns
----------
5 changes: 1 addition & 4 deletions python/cuml/test/test_fil.py
@@ -70,22 +70,19 @@ def _build_and_save_xgboost(model_path,
dtrain = xgb.DMatrix(X_train, label=y_train)

# instantiate params
params = {'silent': 1}
params = {'eval_metric': 'error', 'max_depth': 25}

# learning task params
if classification:
params['eval_metric'] = 'error'
if n_classes == 2:
params['objective'] = 'binary:logistic'
else:
params['num_class'] = n_classes
params['objective'] = 'multi:softprob'
else:
params['eval_metric'] = 'error'
params['objective'] = 'reg:squarederror'
params['base_score'] = 0.0

params['max_depth'] = 25
params.update(xgboost_params)
bst = xgb.train(params, dtrain, num_rounds)
bst.save_model(model_path)
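The parameter-building logic the diff consolidates above (base params, classification vs. regression branches, then `params.update(xgboost_params)` so caller overrides win) can be sketched as a standalone function — `make_xgb_params` is an illustrative name, not part of any API:

```python
def make_xgb_params(classification, n_classes, extra=None):
    # Pure-Python sketch of the parameter logic in _build_and_save_xgboost.
    params = {'eval_metric': 'error', 'max_depth': 25}
    if classification:
        if n_classes == 2:
            params['objective'] = 'binary:logistic'
        else:
            params['num_class'] = n_classes
            params['objective'] = 'multi:softprob'
    else:
        params['objective'] = 'reg:squarederror'
        params['base_score'] = 0.0
    params.update(extra or {})   # caller-supplied params override defaults
    return params

assert make_xgb_params(True, 2)['objective'] == 'binary:logistic'
assert make_xgb_params(True, 5)['num_class'] == 5
assert make_xgb_params(False, 2)['objective'] == 'reg:squarederror'
assert make_xgb_params(True, 2, {'max_depth': 8})['max_depth'] == 8
```

Note that `dict.update` applied last is what lets `xgboost_params` override defaults such as `max_depth` — the same effect the removed trailing `params['max_depth'] = 25` line used to undercut.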