Skip to content

Commit

Permalink
Enable probability output from RF binary classifier (alternative impl…
Browse files Browse the repository at this point in the history
…ementation) (rapidsai#3869)

Alternative implementation of rapidsai#3862 that does not depend on rapidsai#3854
Closes rapidsai#3764
Closes rapidsai#2518

Authors:
  - Philip Hyunsu Cho (https://github.com/hcho3)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)
  - Vinay Deshpande (https://github.com/vinaydes)

URL: rapidsai#3869
  • Loading branch information
hcho3 authored May 27, 2021
1 parent 474e2e7 commit 92484fb
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 13 deletions.
30 changes: 21 additions & 9 deletions cpp/src/decisiontree/batched-levelalgo/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,28 @@ struct ClsDeviceTraits {
atomicAdd(shist + label, 1);
}
__syncthreads();
auto op = Int2Max();
int2 v = {-1, -1};
for (int i = tid; i < input.nclasses; i += blockDim.x) {
int2 tmp = {i, shist[i]};
v = op(v, tmp);
}
v = BlockReduceT(temp).Reduce(v, op);
__syncthreads();
if (tid == 0) {
nodes[0].makeLeaf(n_leaves, LabelT(v.x));
int max_class_idx = 0;
int max_count = 0;
int total_count = 0;
for (int i = 0; i < input.nclasses; ++i) {
int current_count = shist[i];
total_count += current_count;
if (current_count > max_count) {
max_class_idx = i;
max_count = current_count;
}
}
DataT aux = DataT(-1);
if (input.nclasses <= 2) {
// Special handling for binary classifiers
if (input.nclasses == 2) {
aux = static_cast<DataT>(shist[1]) / total_count;
} else {
aux = static_cast<DataT>(0);
}
}
nodes[0].makeLeaf(n_leaves, LabelT(max_class_idx), aux);
}
}
}; // struct ClsDeviceTraits
Expand Down
5 changes: 3 additions & 2 deletions cpp/src/decisiontree/batched-levelalgo/node.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,11 @@ struct Node {
*
* @note to be called only by one thread across all participating threadblocks
*/
DI void makeLeaf(IdxT* n_leaves, LabelT pred) volatile {
DI void makeLeaf(IdxT* n_leaves, LabelT pred,
DataT aux = DataT(-1)) volatile {
info.prediction = pred;
info.colid = Leaf;
info.quesval = DataT(0); // don't care for leaf nodes
info.quesval = aux;
info.best_metric_val = DataT(0); // don't care for leaf nodes
info.left_child_id = Leaf;
atomicAdd(n_leaves, 1);
Expand Down
19 changes: 18 additions & 1 deletion cpp/src/decisiontree/decisiontree_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ std::string get_node_json(const std::string &prefix,
if (node.instance_count != UINT32_MAX) {
oss << ", \"instance_count\": " << node.instance_count;
}
if (node.quesval >= 0) {
oss << ", \"positive_fraction\": " << node.quesval;
}
oss << "}";
}
return oss.str();
Expand Down Expand Up @@ -202,7 +205,21 @@ tl::Tree<T, T> build_treelite_tree(

} else {
if (num_class == 1) {
tl_tree.SetLeaf(node_id, static_cast<T>(q_node.node->prediction));
if (std::is_same<decltype(q_node.node->prediction), int>::value) {
// Binary classification; use fraction of the positive class
// to produce a "soft output"
// Note. The old RF backend doesn't provide this fraction
static_assert(std::is_floating_point<T>::value,
"Expected T to be a floating-point type");
if (q_node.node->quesval >= 0) {
tl_tree.SetLeaf(node_id, static_cast<T>(q_node.node->quesval));
} else {
tl_tree.SetLeaf(node_id, static_cast<T>(q_node.node->prediction));
}
} else {
// Regression
tl_tree.SetLeaf(node_id, static_cast<T>(q_node.node->prediction));
}
} else {
std::vector<T> leaf_vector(num_class, 0);
leaf_vector[q_node.node->prediction] = 1;
Expand Down
1 change: 1 addition & 0 deletions cpp/src/decisiontree/levelalgo/levelhelper_classifier.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ void leaf_eval_classification(
if (condition) {
node_flag = 0xFFFFFFFF;
sparsetree[sparsesize + sparse_nodeid].colid = -1;
sparsetree[sparsesize + sparse_nodeid].quesval = -1;
sparsetree[sparsesize + sparse_nodeid].prediction =
get_class_hist(nodehist, n_unique_labels);
} else {
Expand Down
5 changes: 4 additions & 1 deletion cpp/src/decisiontree/levelalgo/levelkernel_classifier.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -452,6 +452,7 @@ __global__ void best_split_gather_classification_kernel(
colid = -1;
localnode.prediction =
get_class_hist_shared(shmemhist_parent, n_unique_labels);
localnode.quesval = -1;
}
localnode.colid = colid;
localnode.best_metric_val = parent_metric;
Expand Down Expand Up @@ -574,6 +575,7 @@ __global__ void best_split_gather_classification_minmax_kernel(
colid = -1;
localnode.prediction =
get_class_hist_shared(shmemhist_parent, n_unique_labels);
localnode.quesval = -1;
}
localnode.colid = colid;
localnode.best_metric_val = parent_metric;
Expand Down Expand Up @@ -611,6 +613,7 @@ __global__ void make_leaf_gather_classification_kernel(
localnode.prediction =
get_class_hist_shared(shmemhist_parent, n_unique_labels);
localnode.colid = -1;
localnode.quesval = -1;
localnode.best_metric_val = parent_metric;
d_sparsenodes[d_nodelist[blockIdx.x]] = localnode;
}
Expand Down

0 comments on commit 92484fb

Please sign in to comment.