Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support cuML / scikit-learn RF classifiers in TreeExplainer #4447

Merged
merged 30 commits into from
Jan 26, 2022
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
9eb9e99
Break up path processing logic in TreeSHAP explainer
hcho3 Dec 4, 2021
e715724
Fix style
hcho3 Dec 4, 2021
94db1a1
Support cuML RF classifier in TreeExplainer
hcho3 Dec 10, 2021
386b36b
Merge remote-tracking branch 'origin/branch-22.02' into classifier_su…
hcho3 Dec 14, 2021
3cde92a
Fix style
hcho3 Dec 14, 2021
7b66105
Remove print()
hcho3 Dec 14, 2021
87ebc93
Test multiple input types in test_cuml_rf_classifier
hcho3 Dec 14, 2021
21a2bb9
Test scikit-learn RF regressors and classifiers
hcho3 Dec 14, 2021
ade0448
Make scikit-learn optional
hcho3 Dec 14, 2021
e3667c4
Fix style
hcho3 Dec 14, 2021
87458a6
Consolidate path extraction logic
hcho3 Dec 17, 2021
d0dcefd
Use shap.explainers.Tree
hcho3 Dec 17, 2021
729e98d
Fix style
hcho3 Dec 18, 2021
80e45a5
Use weighted sample count in sklearn models
hcho3 Dec 20, 2021
69d6461
Add missing skipif mark
hcho3 Dec 20, 2021
cc54ae1
Extract traverse_towards_leaf_node()
hcho3 Dec 22, 2021
4da8485
Eliminate the use of reference parameter
hcho3 Dec 22, 2021
b37d638
Fix style
hcho3 Dec 22, 2021
8092a1d
Fix style
hcho3 Jan 4, 2022
5fbc641
Update copyright years
hcho3 Jan 4, 2022
743e2eb
Merge remote-tracking branch 'origin/branch-22.02' into classifier_su…
hcho3 Jan 13, 2022
d2da04e
Relax test tolerance
hcho3 Jan 13, 2022
8fc907f
Temporarily use Treelite 2.2.0 for testing
hcho3 Jan 13, 2022
f15c7a5
Temporarily use Treelite 2.2.0 for testing
hcho3 Jan 14, 2022
589d18a
Fix copyright years
hcho3 Jan 14, 2022
372b7b9
Use gpuci_conda_retry to remove metapackages
hcho3 Jan 14, 2022
651a9ae
Merge remote-tracking branch 'origin/branch-22.02' into classifier_su…
hcho3 Jan 15, 2022
0a4cc3e
Use Treelite 2.2.1
hcho3 Jan 15, 2022
5dd56c8
Merge remote-tracking branch 'origin/branch-22.02' into classifier_su…
hcho3 Jan 25, 2022
cdb46d3
Remove temporary install step in build.sh
hcho3 Jan 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 149 additions & 73 deletions cpp/src/explainer/tree_shap.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,10 +17,12 @@
#include <GPUTreeShap/gpu_treeshap.h>
#include <thrust/device_ptr.h>
#include <treelite/tree.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cuml/explainer/tree_shap.hpp>
#include <iostream>
#include <limits>
#include <memory>
#include <raft/error.hpp>
#include <type_traits>
Expand Down Expand Up @@ -135,8 +137,7 @@ void gpu_treeshap_impl(const TreePathInfoImpl<ThresholdType>* path_info,
DenseDatasetWrapper X(data, n_rows, n_cols);

std::size_t num_groups = 1;
if (path_info->task_type == tl::TaskType::kMultiClfGrovePerClass &&
path_info->task_param.num_class > 1) {
if (path_info->task_param.num_class > 1) {
num_groups = static_cast<std::size_t>(path_info->task_param.num_class);
}
std::size_t pred_size = n_rows * num_groups * (n_cols + 1);
Expand Down Expand Up @@ -171,88 +172,163 @@ void gpu_treeshap_impl(const TreePathInfoImpl<ThresholdType>* path_info,
namespace ML {
namespace Explainer {

// Collect the path segments that connect the root of `tree` to the leaf `leaf_node_id`.
//
// The walk starts at the leaf and climbs towards the root by repeatedly looking up the parent of
// the current node in `parent_id` (an entry of -1 marks the node that has no parent). One segment
// is appended per split node visited, followed by one final segment for the root itself, so the
// returned vector is ordered from the leaf end of the path to the root end.
//
// Each segment produced here is only partially filled: path_idx is set to ~std::size_t(0), the
// group to -1 and the leaf value (v) to NaN. The caller is responsible for overwriting these
// three fields.
//
// Fails (RAFT_FAIL) if neither sum-hessian nor data-count statistics are available for a
// parent/child pair, or if a categorical split is encountered on the path.
template <typename ThresholdType, typename LeafType>
std::vector<gpu_treeshap::PathElement<SplitCondition<ThresholdType>>> traverse_towards_leaf_node(
  const tl::Tree<ThresholdType, LeafType>& tree,
  int leaf_node_id,
  const std::vector<int>& parent_id)
{
  using SplitCond   = SplitCondition<ThresholdType>;
  using PathSegment = gpu_treeshap::PathElement<SplitCond>;

  constexpr auto kInf = std::numeric_limits<ThresholdType>::infinity();
  // Placeholder values for the fields the caller fills in later
  const std::size_t unset_path_idx = ~std::size_t(0);
  const float unset_leaf_value     = std::numeric_limits<float>::quiet_NaN();

  std::vector<PathSegment> segments;

  int node = leaf_node_id;
  for (int parent = parent_id[node]; parent != -1; parent = parent_id[node]) {
    // Fraction of the parent's weight that flows into `node`. The sum of hessians is preferred;
    // the data count is consulted only when the sum of hessians is unavailable.
    double child_fraction = 1.0;
    if (tree.HasSumHess(parent) && tree.HasSumHess(node)) {
      child_fraction = static_cast<double>(tree.SumHess(node) / tree.SumHess(parent));
    } else if (tree.HasDataCount(parent) && tree.HasDataCount(node)) {
      child_fraction = static_cast<double>(tree.DataCount(node)) / tree.DataCount(parent);
    } else {
      RAFT_FAIL("Tree model doesn't have data count information");
    }

    const bool went_left = (tree.LeftChild(parent) == node);
    if (tree.SplitType(parent) == tl::SplitFeatureType::kCategorical) {
      RAFT_FAIL(
        "Only trees with numerical splits are supported. "
        "Trees with categorical splits are not supported yet.");
    }

    // Range of feature values that is routed to `node` by the split at `parent`:
    // the left branch is bounded above by the threshold, the right branch is bounded below by it.
    ThresholdType lower = -kInf;
    ThresholdType upper = kInf;
    if (went_left) {
      upper = tree.Threshold(parent);
    } else {
      lower = tree.Threshold(parent);
    }

    segments.push_back(PathSegment{unset_path_idx,
                                   tree.SplitIndex(parent),
                                   -1,
                                   SplitCond{lower, upper, tree.ComparisonOp(parent)},
                                   child_fraction,
                                   unset_leaf_value});

    // Climb one level; the loop header then looks up the parent of the new current node
    node = parent;
  }

  // `node` is now the root. Its segment uses feature index -1, an unbounded value range and a
  // fraction of 1.0.
  segments.push_back(PathSegment{unset_path_idx,
                                 -1,
                                 -1,
                                 SplitCond{-kInf, kInf, tree.ComparisonOp(node)},
                                 1.0,
                                 unset_leaf_value});
  return segments;
}

// Build the path segments for every root-to-leaf path of the tree tree_list[tree_idx].
//
// Every emitted segment carries a path_idx identifying the path it belongs to. Path indices are
// handed out consecutively, starting at path_idx_offset.
//
//  * use_vector_leaf == true : each leaf stores one value per group. The path leading to the
//    leaf is emitted num_groups times, once per group, each copy with its own path_idx, its
//    group id and the matching entry of the leaf vector as the leaf value.
//  * use_vector_leaf == false: each leaf stores a single value. The path is emitted once and is
//    assigned to group (tree_idx % num_groups).
//
// Fails (RAFT_FAIL) if num_groups < 1, or if a leaf vector's length differs from num_groups.
template <typename ThresholdType, typename LeafType>
std::vector<gpu_treeshap::PathElement<SplitCondition<ThresholdType>>>
extract_path_segments_from_tree(const std::vector<tl::Tree<ThresholdType, LeafType>>& tree_list,
                                std::size_t tree_idx,
                                bool use_vector_leaf,
                                int num_groups,
                                std::size_t path_idx_offset)
{
  using PathSegment = gpu_treeshap::PathElement<SplitCondition<ThresholdType>>;

  if (num_groups < 1) { RAFT_FAIL("num_groups must be at least 1"); }

  const tl::Tree<ThresholdType, LeafType>& tree = tree_list[tree_idx];

  // parent_of[n] is the ID of the node whose left or right child is n; -1 if there is none
  std::vector<int> parent_of(tree.num_nodes, -1);
  for (int node = 0; node < tree.num_nodes; ++node) {
    if (tree.IsLeaf(node)) { continue; }
    parent_of[tree.LeftChild(node)]  = node;
    parent_of[tree.RightChild(node)] = node;
  }

  std::vector<PathSegment> all_segments;
  std::size_t next_path_idx = path_idx_offset;

  // Stamp the caller-owned fields (path index, leaf value, group) onto every segment of `path`
  // and append the stamped segments to the output. `path` is modified in place, so the same
  // path can be stamped and appended repeatedly with different values.
  const auto stamp_and_append = [&all_segments](std::vector<PathSegment>& path,
                                                auto leaf_value,
                                                std::size_t path_index,
                                                int group_id) {
    for (auto& segment : path) {
      segment.path_idx = path_index;
      segment.v        = static_cast<float>(leaf_value);
      segment.group    = group_id;
    }
    all_segments.insert(all_segments.end(), path.cbegin(), path.cend());
  };

  // Group used for every path of this tree when leaves hold a single scalar value
  const int scalar_leaf_group = static_cast<int>(tree_idx) % num_groups;

  for (int leaf = 0; leaf < tree.num_nodes; ++leaf) {
    if (!tree.IsLeaf(leaf)) { continue; }

    // Segments of the path connecting this leaf and the root
    auto path = traverse_towards_leaf_node(tree, leaf, parent_of);

    if (!use_vector_leaf) {
      stamp_and_append(path, tree.LeafValue(leaf), next_path_idx, scalar_leaf_group);
      ++next_path_idx;
      continue;
    }

    const auto leaf_values = tree.LeafVector(leaf);
    if (leaf_values.size() != static_cast<std::size_t>(num_groups)) {
      RAFT_FAIL("Expected leaf vector of length %d but got %d instead",
                num_groups,
                static_cast<int>(leaf_values.size()));
    }
    // One copy of the path per group, each with a distinct path index
    for (int group = 0; group < num_groups; ++group) {
      stamp_and_append(path, leaf_values[group], next_path_idx, group);
      ++next_path_idx;
    }
  }
  return all_segments;
}

template <typename ThresholdType, typename LeafType>
std::unique_ptr<TreePathInfo> extract_path_info_impl(
const tl::ModelImpl<ThresholdType, LeafType>& model)
{
if (!std::is_same<ThresholdType, LeafType>::value) {
RAFT_FAIL("ThresholdType and LeafType must be identical");
}
if (model.task_type != tl::TaskType::kBinaryClfRegr &&
model.task_type != tl::TaskType::kMultiClfGrovePerClass) {
RAFT_FAIL("cuML RF / scikit-learn classifiers are not yet supported");
if (!std::is_same<ThresholdType, float>::value && !std::is_same<ThresholdType, double>::value) {
RAFT_FAIL("ThresholdType must be either float32 or float64");
}

std::unique_ptr<TreePathInfo> path_info_ptr = std::make_unique<TreePathInfoImpl<ThresholdType>>();
auto* path_info = dynamic_cast<TreePathInfoImpl<ThresholdType>*>(path_info_ptr.get());

std::size_t path_idx = 0;
int tree_idx = 0;
int num_groups = 1;
if (model.task_type == tl::TaskType::kMultiClfGrovePerClass && model.task_param.num_class > 1) {
num_groups = model.task_param.num_class;
int num_groups = 1;
bool use_vector_leaf;
if (model.task_param.num_class > 1) { num_groups = model.task_param.num_class; }
if (model.task_type == tl::TaskType::kBinaryClfRegr ||
model.task_type == tl::TaskType::kMultiClfGrovePerClass) {
use_vector_leaf = false;
} else if (model.task_type == tl::TaskType::kMultiClfProbDistLeaf) {
use_vector_leaf = true;
} else {
RAFT_FAIL("Unsupported task_type: %d", static_cast<int>(model.task_type));
}
for (const tl::Tree<ThresholdType, LeafType>& tree : model.trees) {
int group_id = tree_idx % num_groups;
std::vector<int> parent_id(tree.num_nodes, -1);
// Compute parent ID of each node
for (int i = 0; i < tree.num_nodes; i++) {
if (!tree.IsLeaf(i)) {
parent_id[tree.LeftChild(i)] = i;
parent_id[tree.RightChild(i)] = i;
}
}

// Find leaf nodes
// Work backwards from leaf to root, order does not matter
// It's also possible to work from root to leaf
for (int i = 0; i < tree.num_nodes; i++) {
if (tree.IsLeaf(i)) {
auto v = static_cast<float>(tree.LeafValue(i));
int child_idx = i;
int parent_idx = parent_id[child_idx];
constexpr auto inf = std::numeric_limits<ThresholdType>::infinity();
tl::Operator comparison_op = tl::Operator::kNone;
while (parent_idx != -1) {
double zero_fraction = 1.0;
bool has_count_info = false;
if (tree.HasSumHess(parent_idx) && tree.HasSumHess(child_idx)) {
zero_fraction = static_cast<double>(tree.SumHess(child_idx) / tree.SumHess(parent_idx));
has_count_info = true;
}
if (tree.HasDataCount(parent_idx) && tree.HasDataCount(child_idx)) {
zero_fraction =
static_cast<double>(tree.DataCount(child_idx)) / tree.DataCount(parent_idx);
has_count_info = true;
}
if (!has_count_info) { RAFT_FAIL("Tree model doesn't have data count information"); }
// Encode the range of feature values that flow down this path
bool is_left_path = tree.LeftChild(parent_idx) == child_idx;
if (tree.SplitType(parent_idx) == tl::SplitFeatureType::kCategorical) {
RAFT_FAIL(
"Only trees with numerical splits are supported. "
"Trees with categorical splits are not supported yet.");
}
ThresholdType lower_bound = is_left_path ? -inf : tree.Threshold(parent_idx);
ThresholdType upper_bound = is_left_path ? tree.Threshold(parent_idx) : inf;
comparison_op = tree.ComparisonOp(parent_idx);
path_info->paths.push_back(gpu_treeshap::PathElement<SplitCondition<ThresholdType>>{
path_idx,
tree.SplitIndex(parent_idx),
group_id,
SplitCondition{lower_bound, upper_bound, comparison_op},
zero_fraction,
v});
child_idx = parent_idx;
parent_idx = parent_id[child_idx];
}
// Root node has feature -1
comparison_op = tree.ComparisonOp(child_idx);
path_info->paths.push_back(gpu_treeshap::PathElement<SplitCondition<ThresholdType>>{
path_idx, -1, group_id, SplitCondition{-inf, inf, comparison_op}, 1.0, v});
path_idx++;
}
}
tree_idx++;
std::size_t path_idx = 0;
for (std::size_t tree_idx = 0; tree_idx < model.trees.size(); ++tree_idx) {
auto path_segments =
extract_path_segments_from_tree(model.trees, tree_idx, use_vector_leaf, num_groups, path_idx);
path_info->paths.insert(path_info->paths.end(), path_segments.cbegin(), path_segments.cend());
if (!path_segments.empty()) { path_idx = path_segments.back().path_idx + 1; }
}
path_info->global_bias = model.param.global_bias;
path_info->task_type = model.task_type;
Expand Down
21 changes: 16 additions & 5 deletions python/cuml/experimental/explainer/tree_shap.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@

from cuml.common import input_to_cuml_array
from cuml.common.array import CumlArray
from cuml.common.import_utils import has_sklearn
from cuml.common.input_utils import determine_array_type
from cuml.common.exceptions import NotFittedError
from cuml.fil.fil import TreeliteModel
Expand All @@ -28,6 +29,15 @@ from libcpp.utility cimport move
import numpy as np
import treelite

# scikit-learn is optional. When it is not installed, bind both random forest
# aliases to an empty stand-in class so that isinstance() checks against
# sklrfr / sklrfc remain valid.
if not has_sklearn():
    class PlaceHolder:
        pass
    sklrfr = sklrfc = PlaceHolder
else:
    from sklearn.ensemble import RandomForestRegressor as sklrfr
    from sklearn.ensemble import RandomForestClassifier as sklrfc

cdef extern from "treelite/c_api.h":
ctypedef void* ModelHandle
cdef int TreeliteQueryNumClass(ModelHandle handle, size_t* out)
Expand Down Expand Up @@ -105,16 +115,17 @@ class TreeExplainer:
model = treelite.Model.from_xgboost(model)
handle = model.handle.value
# cuML RF model object
elif isinstance(model, curfr):
elif isinstance(model, (curfr, curfc)):
try:
model = model.convert_to_treelite_model()
except NotFittedError as e:
raise NotFittedError(
'Cannot compute SHAP for un-fitted model') from e
handle = model.handle
elif isinstance(model, curfc):
raise NotImplementedError(
'cuML RF classifiers are not supported yet')
# scikit-learn RF model object
elif isinstance(model, (sklrfr, sklrfc)):
model = treelite.sklearn.import_model(model)
handle = model.handle.value
elif isinstance(model, treelite.Model):
handle = model.handle.value
elif isinstance(model, TreeliteModel):
Expand Down
Loading