Skip to content

Commit

Permalink
Add predict_proba() to XGBoost-style models in FIL C++ (#2894)
Browse files Browse the repository at this point in the history
Authors:
  - @levsnv

Approvers:
  - Andy Adinets (@canonizer)
  - John Zedlewski (@JohnZed)

URL: #2894
  • Loading branch information
levsnv authored Mar 10, 2021
1 parent 0967a00 commit 8b78fa3
Show file tree
Hide file tree
Showing 7 changed files with 399 additions and 215 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ Please see https://github.com/rapidsai/cuml/releases/tag/branch-0.19-latest for
- PR #2659: Add initial max inner product sparse knn
- PR #3092: Multiclass meta estimator wrappers and multiclass SVC
- PR #2836: Refactor UMAP to accept sparse inputs
- PR #2894: predict_proba in FIL C++ for XGBoost-style multi-class models
- PR #3126: Experimental versions of GPU accelerated Kernel and Permutation SHAP

## Improvements
Expand Down
6 changes: 5 additions & 1 deletion cpp/src/fil/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ struct shmem_size_params {
leaf_algo_t leaf_algo = leaf_algo_t::FLOAT_UNARY_BINARY;
/// how many columns an input row has
int num_cols = 0;
/// whether to predict class probabilities or classes (or regress)
bool predict_proba = false;
/// are the input columns are prefetched into shared
/// memory before inferring the row in question
bool cols_in_shmem = true;
Expand Down Expand Up @@ -148,7 +150,9 @@ struct predict_params : shmem_size_params {
// number of data rows (instances) to predict on
size_t num_rows;

// Other parameters.
// to signal infer kernel to apply softmax and also average prior to that
// for GROVE_PER_CLASS for predict_proba
output_t transform;
int num_blocks;
};

Expand Down
103 changes: 66 additions & 37 deletions cpp/src/fil/fil.cu
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,23 @@ struct forest {
// searching for the most items per block while respecting the shared
// memory limits creates a full linear programming problem.
// solving it in a single equation looks less tractable than this
shmem_size_params ssp = ssp_;
for (bool cols_in_shmem : {false, true}) {
ssp.cols_in_shmem = cols_in_shmem;
for (ssp.n_items = 1;
ssp.n_items <= (algo_ == algo_t::BATCH_TREE_REORG ? 4 : 1);
++ssp.n_items) {
ssp.compute_smem_footprint();
if (ssp.shm_sz < max_shm) ssp_ = ssp;
for (bool predict_proba : {false, true}) {
shmem_size_params& ssp_ = predict_proba ? proba_ssp_ : class_ssp_;
ssp_.predict_proba = predict_proba;
shmem_size_params ssp = ssp_;
for (bool cols_in_shmem : {false, true}) {
ssp.cols_in_shmem = cols_in_shmem;
for (ssp.n_items = 1;
ssp.n_items <= (algo_ == algo_t::BATCH_TREE_REORG ? 4 : 1);
++ssp.n_items) {
ssp.compute_smem_footprint();
if (ssp.shm_sz < max_shm) ssp_ = ssp;
}
}
ASSERT(max_shm >= ssp_.shm_sz,
"FIL out of shared memory. Perhaps the maximum number of \n"
"supported classes is exceeded? 5'000 would still be safe.");
}
ASSERT(max_shm >= ssp_.shm_sz,
"FIL out of shared memory. Perhaps the maximum number of \n"
"supported classes is exceeded? 5'000 would still be safe.");
}

void init_fixed_block_count(int device, int blocks_per_sm) {
Expand All @@ -118,9 +122,10 @@ struct forest {
output_ = params->output;
threshold_ = params->threshold;
global_bias_ = params->global_bias;
ssp_.leaf_algo = params->leaf_algo;
ssp_.num_cols = params->num_cols;
ssp_.num_classes = params->num_classes;
proba_ssp_.leaf_algo = params->leaf_algo;
proba_ssp_.num_cols = params->num_cols;
proba_ssp_.num_classes = params->num_classes;
class_ssp_ = proba_ssp_;

int device = h.get_device();
init_n_items(device); // n_items takes priority over blocks_per_sm
Expand All @@ -132,11 +137,13 @@ struct forest {
void predict(const raft::handle_t& h, float* preds, const float* data,
size_t num_rows, bool predict_proba) {
// Initialize prediction parameters.
predict_params params(ssp_);
predict_params params(predict_proba ? proba_ssp_ : class_ssp_);
params.algo = algo_;
params.preds = preds;
params.data = data;
params.num_rows = num_rows;
// ignored unless predict_proba is true and algo is GROVE_PER_CLASS
params.transform = output_;
// fixed_block_count_ == 0 means the number of thread blocks is
// proportional to the number of rows
params.num_blocks = fixed_block_count_;
Expand All @@ -148,6 +155,7 @@ struct forest {
AVG is set: divide by the number of trees (averaging)
SIGMOID is set: apply sigmoid
CLASS is set: ignored
SOFTMAX is set: error
write the output of the previous stages and its complement
The binary classification / regression (FLOAT_UNARY_BINARY) predict() works as follows
Expand All @@ -156,55 +164,70 @@ struct forest {
AVG is set: divide by the number of trees (averaging)
SIGMOID is set: apply sigmoid
CLASS is set: apply threshold (equivalent to choosing best class)
SOFTMAX is set: error
The multi-class classification / regression (CATEGORICAL_LEAF) predict_proba() works as follows
(always num_classes outputs):
RAW (no values set): output class votes
AVG is set: divide by the number of trees (averaging, output class probability)
SIGMOID is set: apply sigmoid
CLASS is set: ignored
SOFTMAX is set: error
The multi-class classification / regression (CATEGORICAL_LEAF) predict() works as follows
(always 1 output):
RAW (no values set): output the label of the class with highest probability, else output label 0.
SOFTMAX is set: error
All other flags (AVG, SIGMOID, CLASS) are ignored
The multi-class classification / regression (GROVE_PER_CLASS) predict_proba() is not implemented
The multi-class classification / regression (GROVE_PER_CLASS) predict_proba() works as follows
(always num_classes outputs):
RAW (no values set): output class votes
AVG is set: divide by the number of trees (averaging, output class probability)
SIGMOID is set: apply sigmoid; if SOFTMAX is also set: error
CLASS is set: ignored
SOFTMAX is set: softmax is applied after averaging and global_bias
The multi-class classification / regression (GROVE_PER_CLASS) predict() works as follows
(always 1 output):
RAW (no values set): output the label of the class with highest margin,
equal margins resolved in favor of smaller label integer
All other flags (AVG, SIGMOID, CLASS) are ignored
All other flags (AVG, SIGMOID, CLASS, SOFTMAX) are ignored
*/
output_t ot = output_;
// Treelite applies bias before softmax, but we do after.
// Simulating treelite order, which cancels out bias.
// If non-proba prediction used, it still will not matter
// for the same reason softmax will not.
float global_bias = (ot & output_t::SOFTMAX) != 0 ? 0.0f : global_bias_;
bool complement_proba = false, do_transform;

if (predict_proba) {
// no threshold on probabilities
ot = output_t(ot & ~output_t::CLASS);

switch (ssp_.leaf_algo) {
switch (params.leaf_algo) {
case leaf_algo_t::FLOAT_UNARY_BINARY:
params.num_outputs = 2;
complement_proba = true;
do_transform = true;
break;
case leaf_algo_t::GROVE_PER_CLASS:
// TODO(levsnv): add softmax to implement predict_proba
ASSERT(
false,
"predict_proba not supported for multi-class gradient boosted "
"decision trees (encountered in xgboost, scikit-learn, lightgbm)");
// for GROVE_PER_CLASS, averaging happens in infer_k
ot = output_t(ot & ~output_t::AVG);
params.num_outputs = params.num_classes;
do_transform = ot != output_t::RAW && ot != output_t::SOFTMAX ||
global_bias != 0.0f;
break;
case leaf_algo_t::CATEGORICAL_LEAF:
params.num_outputs = ssp_.num_classes;
params.num_outputs = params.num_classes;
do_transform = ot != output_t::RAW || global_bias_ != 0.0f;
break;
default:
ASSERT(false, "internal error: invalid leaf_algo_");
}
} else {
if (ssp_.leaf_algo == leaf_algo_t::FLOAT_UNARY_BINARY) {
if (params.leaf_algo == leaf_algo_t::FLOAT_UNARY_BINARY) {
do_transform = ot != output_t::RAW || global_bias_ != 0.0f;
} else {
// GROVE_PER_CLASS, CATEGORICAL_LEAF: moot since choosing best class and
Expand All @@ -224,7 +247,7 @@ struct forest {
transform_k<<<raft::ceildiv(num_values_to_transform, (size_t)FIL_TPB),
FIL_TPB, 0, stream>>>(
preds, num_values_to_transform, ot,
num_trees_ > 0 ? (1.0f / num_trees_) : 1.0f, threshold_, global_bias_,
num_trees_ > 0 ? (1.0f / num_trees_) : 1.0f, threshold_, global_bias,
complement_proba);
CUDA_CHECK(cudaPeekAtLastError());
}
Expand All @@ -239,7 +262,7 @@ struct forest {
output_t output_ = output_t::RAW;
float threshold_ = 0.5;
float global_bias_ = 0;
shmem_size_params ssp_;
shmem_size_params class_ssp_, proba_ssp_;
int fixed_block_count_ = 0;
};

Expand Down Expand Up @@ -381,6 +404,8 @@ void check_params(const forest_params_t* params, bool dense) {
"num_classes must be 1 for "
"regression");
}
ASSERT((params->output & output_t::SOFTMAX) == 0,
"softmax does not make sense for leaf_algo == FLOAT_UNARY_BINARY");
break;
case leaf_algo_t::GROVE_PER_CLASS:
ASSERT(params->num_classes > 2,
Expand All @@ -392,19 +417,21 @@ void check_params(const forest_params_t* params, bool dense) {
ASSERT(params->num_classes >= 2,
"num_classes >= 2 is required for "
"leaf_algo == CATEGORICAL_LEAF");
ASSERT((params->output & output_t::SOFTMAX) == 0,
"softmax not supported for leaf_algo == CATEGORICAL_LEAF");
break;
default:
ASSERT(false,
"leaf_algo must be FLOAT_UNARY_BINARY, CATEGORICAL_LEAF"
" or GROVE_PER_CLASS");
}
// output_t::RAW == 0, and doesn't have a separate flag
output_t all_set =
output_t(output_t::AVG | output_t::SIGMOID | output_t::CLASS);
if ((params->output & ~all_set) != 0) {
ASSERT(false,
"output should be a combination of RAW, AVG, SIGMOID and CLASS");
if ((params->output & ~output_t::ALL_SET) != 0) {
ASSERT(
false,
"output should be a combination of RAW, AVG, SIGMOID, CLASS and SOFTMAX");
}
ASSERT(~params->output & (output_t::SIGMOID | output_t::SOFTMAX),
"combining softmax and sigmoid is not supported");
ASSERT(params->blocks_per_sm >= 0, "blocks_per_sm must be nonnegative");
}

Expand Down Expand Up @@ -654,10 +681,10 @@ void tl2fil_common(forest_params_t* params, const tl::ModelImpl<T, L>& model,
params->num_classes = static_cast<int>(model.task_param.num_class);
ASSERT(tl_params->output_class,
"output_class==true is required for multi-class models");
ASSERT(pred_transform == "sigmoid" || pred_transform == "identity" ||
ASSERT(pred_transform == "identity_multiclass" ||
pred_transform == "max_index" || pred_transform == "softmax" ||
pred_transform == "multiclass_ova",
"only sigmoid, identity, max_index, multiclass_ova and softmax "
"only identity_multiclass, max_index, multiclass_ova and softmax "
"values of pred_transform are supported for xgboost-style "
"multi-class classification models.");
// this function should not know how many threads per block will be used
Expand Down Expand Up @@ -689,9 +716,11 @@ void tl2fil_common(forest_params_t* params, const tl::ModelImpl<T, L>& model,
if (model.average_tree_output) {
params->output = output_t(params->output | output_t::AVG);
}
if (std::string(param.pred_transform) == "sigmoid") {
if (pred_transform == "sigmoid" || pred_transform == "multiclass_ova") {
params->output = output_t(params->output | output_t::SIGMOID);
}
if (pred_transform == "softmax")
params->output = output_t(params->output | output_t::SOFTMAX);
params->num_trees = model.trees.size();
params->blocks_per_sm = tl_params->blocks_per_sm;
}
Expand Down
Loading

0 comments on commit 8b78fa3

Please sign in to comment.