From 9eb9e99b128a1a02bc65e6df265fadb707219fee Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Sat, 4 Dec 2021 00:44:58 +0000 Subject: [PATCH 01/38] Break up path processing logic in TreeSHAP explainer --- cpp/src/explainer/tree_shap.cu | 133 ++++++++++++++++++--------------- 1 file changed, 72 insertions(+), 61 deletions(-) diff --git a/cpp/src/explainer/tree_shap.cu b/cpp/src/explainer/tree_shap.cu index 3d54dd11ac..1a660e67f7 100644 --- a/cpp/src/explainer/tree_shap.cu +++ b/cpp/src/explainer/tree_shap.cu @@ -171,6 +171,77 @@ void gpu_treeshap_impl(const TreePathInfoImpl* path_info, namespace ML { namespace Explainer { +template +void extract_path_info_from_tree( + const tl::Tree& tree, + int num_groups, + int& tree_idx, + std::size_t& path_idx, + TreePathInfoImpl& path_info) +{ + int group_id = tree_idx % num_groups; + std::vector parent_id(tree.num_nodes, -1); + // Compute parent ID of each node + for (int i = 0; i < tree.num_nodes; i++) { + if (!tree.IsLeaf(i)) { + parent_id[tree.LeftChild(i)] = i; + parent_id[tree.RightChild(i)] = i; + } + } + + // Find leaf nodes + // Work backwards from leaf to root, order does not matter + // It's also possible to work from root to leaf + for (int i = 0; i < tree.num_nodes; i++) { + if (tree.IsLeaf(i)) { + auto v = static_cast(tree.LeafValue(i)); + int child_idx = i; + int parent_idx = parent_id[child_idx]; + constexpr auto inf = std::numeric_limits::infinity(); + tl::Operator comparison_op = tl::Operator::kNone; + while (parent_idx != -1) { + double zero_fraction = 1.0; + bool has_count_info = false; + if (tree.HasSumHess(parent_idx) && tree.HasSumHess(child_idx)) { + zero_fraction = static_cast(tree.SumHess(child_idx) / tree.SumHess(parent_idx)); + has_count_info = true; + } + if (tree.HasDataCount(parent_idx) && tree.HasDataCount(child_idx)) { + zero_fraction = + static_cast(tree.DataCount(child_idx)) / tree.DataCount(parent_idx); + has_count_info = true; + } + if (!has_count_info) { RAFT_FAIL("Tree model doesn't have 
data count information"); } + // Encode the range of feature values that flow down this path + bool is_left_path = tree.LeftChild(parent_idx) == child_idx; + if (tree.SplitType(parent_idx) == tl::SplitFeatureType::kCategorical) { + RAFT_FAIL( + "Only trees with numerical splits are supported. " + "Trees with categorical splits are not supported yet."); + } + ThresholdType lower_bound = is_left_path ? -inf : tree.Threshold(parent_idx); + ThresholdType upper_bound = is_left_path ? tree.Threshold(parent_idx) : inf; + comparison_op = tree.ComparisonOp(parent_idx); + path_info.paths.push_back(gpu_treeshap::PathElement>{ + path_idx, + tree.SplitIndex(parent_idx), + group_id, + SplitCondition{lower_bound, upper_bound, comparison_op}, + zero_fraction, + v}); + child_idx = parent_idx; + parent_idx = parent_id[child_idx]; + } + // Root node has feature -1 + comparison_op = tree.ComparisonOp(child_idx); + path_info.paths.push_back(gpu_treeshap::PathElement>{ + path_idx, -1, group_id, SplitCondition{-inf, inf, comparison_op}, 1.0, v}); + path_idx++; + } + } + tree_idx++; +} + template std::unique_ptr extract_path_info_impl( const tl::ModelImpl& model) @@ -192,67 +263,7 @@ std::unique_ptr extract_path_info_impl( num_groups = model.task_param.num_class; } for (const tl::Tree& tree : model.trees) { - int group_id = tree_idx % num_groups; - std::vector parent_id(tree.num_nodes, -1); - // Compute parent ID of each node - for (int i = 0; i < tree.num_nodes; i++) { - if (!tree.IsLeaf(i)) { - parent_id[tree.LeftChild(i)] = i; - parent_id[tree.RightChild(i)] = i; - } - } - - // Find leaf nodes - // Work backwards from leaf to root, order does not matter - // It's also possible to work from root to leaf - for (int i = 0; i < tree.num_nodes; i++) { - if (tree.IsLeaf(i)) { - auto v = static_cast(tree.LeafValue(i)); - int child_idx = i; - int parent_idx = parent_id[child_idx]; - constexpr auto inf = std::numeric_limits::infinity(); - tl::Operator comparison_op = tl::Operator::kNone; - 
while (parent_idx != -1) { - double zero_fraction = 1.0; - bool has_count_info = false; - if (tree.HasSumHess(parent_idx) && tree.HasSumHess(child_idx)) { - zero_fraction = static_cast(tree.SumHess(child_idx) / tree.SumHess(parent_idx)); - has_count_info = true; - } - if (tree.HasDataCount(parent_idx) && tree.HasDataCount(child_idx)) { - zero_fraction = - static_cast(tree.DataCount(child_idx)) / tree.DataCount(parent_idx); - has_count_info = true; - } - if (!has_count_info) { RAFT_FAIL("Tree model doesn't have data count information"); } - // Encode the range of feature values that flow down this path - bool is_left_path = tree.LeftChild(parent_idx) == child_idx; - if (tree.SplitType(parent_idx) == tl::SplitFeatureType::kCategorical) { - RAFT_FAIL( - "Only trees with numerical splits are supported. " - "Trees with categorical splits are not supported yet."); - } - ThresholdType lower_bound = is_left_path ? -inf : tree.Threshold(parent_idx); - ThresholdType upper_bound = is_left_path ? tree.Threshold(parent_idx) : inf; - comparison_op = tree.ComparisonOp(parent_idx); - path_info->paths.push_back(gpu_treeshap::PathElement>{ - path_idx, - tree.SplitIndex(parent_idx), - group_id, - SplitCondition{lower_bound, upper_bound, comparison_op}, - zero_fraction, - v}); - child_idx = parent_idx; - parent_idx = parent_id[child_idx]; - } - // Root node has feature -1 - comparison_op = tree.ComparisonOp(child_idx); - path_info->paths.push_back(gpu_treeshap::PathElement>{ - path_idx, -1, group_id, SplitCondition{-inf, inf, comparison_op}, 1.0, v}); - path_idx++; - } - } - tree_idx++; + extract_path_info_from_tree(tree, num_groups, tree_idx, path_idx, *path_info); } path_info->global_bias = model.param.global_bias; path_info->task_type = model.task_type; From e715724b15f586beb9938df58626666260ea75c8 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Sat, 4 Dec 2021 00:52:25 +0000 Subject: [PATCH 02/38] Fix style --- cpp/src/explainer/tree_shap.cu | 13 ++++++------- 1 file changed, 6 
insertions(+), 7 deletions(-) diff --git a/cpp/src/explainer/tree_shap.cu b/cpp/src/explainer/tree_shap.cu index 1a660e67f7..d07023852f 100644 --- a/cpp/src/explainer/tree_shap.cu +++ b/cpp/src/explainer/tree_shap.cu @@ -172,12 +172,11 @@ namespace ML { namespace Explainer { template -void extract_path_info_from_tree( - const tl::Tree& tree, - int num_groups, - int& tree_idx, - std::size_t& path_idx, - TreePathInfoImpl& path_info) +void extract_path_info_from_tree(const tl::Tree& tree, + int num_groups, + int& tree_idx, + std::size_t& path_idx, + TreePathInfoImpl& path_info) { int group_id = tree_idx % num_groups; std::vector parent_id(tree.num_nodes, -1); @@ -203,7 +202,7 @@ void extract_path_info_from_tree( double zero_fraction = 1.0; bool has_count_info = false; if (tree.HasSumHess(parent_idx) && tree.HasSumHess(child_idx)) { - zero_fraction = static_cast(tree.SumHess(child_idx) / tree.SumHess(parent_idx)); + zero_fraction = static_cast(tree.SumHess(child_idx) / tree.SumHess(parent_idx)); has_count_info = true; } if (tree.HasDataCount(parent_idx) && tree.HasDataCount(child_idx)) { From 94db1a17c17707eda4882a87c2aebf53533794be Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 10 Dec 2021 08:23:05 +0000 Subject: [PATCH 03/38] Support cuML RF classifier in TreeExplainer --- cpp/src/explainer/tree_shap.cu | 122 ++++++++++++++++-- .../cuml/experimental/explainer/tree_shap.pyx | 5 +- .../cuml/test/explainer/test_gpu_treeshap.py | 11 +- 3 files changed, 121 insertions(+), 17 deletions(-) diff --git a/cpp/src/explainer/tree_shap.cu b/cpp/src/explainer/tree_shap.cu index d07023852f..16cea4c55c 100644 --- a/cpp/src/explainer/tree_shap.cu +++ b/cpp/src/explainer/tree_shap.cu @@ -25,6 +25,8 @@ #include #include #include +#include +#include namespace tl = treelite; @@ -135,8 +137,7 @@ void gpu_treeshap_impl(const TreePathInfoImpl* path_info, DenseDatasetWrapper X(data, n_rows, n_cols); std::size_t num_groups = 1; - if (path_info->task_type == 
tl::TaskType::kMultiClfGrovePerClass && - path_info->task_param.num_class > 1) { + if (path_info->task_param.num_class > 1) { num_groups = static_cast(path_info->task_param.num_class); } std::size_t pred_size = n_rows * num_groups * (n_cols + 1); @@ -178,7 +179,6 @@ void extract_path_info_from_tree(const tl::Tree& tree, std::size_t& path_idx, TreePathInfoImpl& path_info) { - int group_id = tree_idx % num_groups; std::vector parent_id(tree.num_nodes, -1); // Compute parent ID of each node for (int i = 0; i < tree.num_nodes; i++) { @@ -191,6 +191,7 @@ void extract_path_info_from_tree(const tl::Tree& tree, // Find leaf nodes // Work backwards from leaf to root, order does not matter // It's also possible to work from root to leaf + int group_id = tree_idx % num_groups; for (int i = 0; i < tree.num_nodes; i++) { if (tree.IsLeaf(i)) { auto v = static_cast(tree.LeafValue(i)); @@ -241,6 +242,102 @@ void extract_path_info_from_tree(const tl::Tree& tree, tree_idx++; } +template +void extract_path_info_from_tree_with_leaf_vec(const tl::Tree& tree, + int num_groups, + int& tree_idx, + std::size_t& path_idx, + TreePathInfoImpl& path_info) +{ + if (num_groups < 1) { + RAFT_FAIL("num_groups must be at least 1"); + } + + std::vector parent_id(tree.num_nodes, -1); + // Compute parent ID of each node + for (int i = 0; i < tree.num_nodes; i++) { + if (!tree.IsLeaf(i)) { + parent_id[tree.LeftChild(i)] = i; + parent_id[tree.RightChild(i)] = i; + } + } + + // Find leaf nodes + // Work backwards from leaf to root, order does not matter + // It's also possible to work from root to leaf + for (int i = 0; i < tree.num_nodes; i++) { + if (tree.IsLeaf(i)) { + std::vector>> tmp_paths; + int child_idx = i; + int parent_idx = parent_id[child_idx]; + constexpr auto inf = std::numeric_limits::infinity(); + tl::Operator comparison_op = tl::Operator::kNone; + while (parent_idx != -1) { + double zero_fraction = 1.0; + bool has_count_info = false; + if (tree.HasSumHess(parent_idx) && 
tree.HasSumHess(child_idx)) { + zero_fraction = static_cast(tree.SumHess(child_idx) / tree.SumHess(parent_idx)); + has_count_info = true; + } + if (tree.HasDataCount(parent_idx) && tree.HasDataCount(child_idx)) { + zero_fraction = + static_cast(tree.DataCount(child_idx)) / tree.DataCount(parent_idx); + has_count_info = true; + } + if (!has_count_info) { RAFT_FAIL("Tree model doesn't have data count information"); } + // Encode the range of feature values that flow down this path + bool is_left_path = tree.LeftChild(parent_idx) == child_idx; + if (tree.SplitType(parent_idx) == tl::SplitFeatureType::kCategorical) { + RAFT_FAIL( + "Only trees with numerical splits are supported. " + "Trees with categorical splits are not supported yet."); + } + ThresholdType lower_bound = is_left_path ? -inf : tree.Threshold(parent_idx); + ThresholdType upper_bound = is_left_path ? tree.Threshold(parent_idx) : inf; + comparison_op = tree.ComparisonOp(parent_idx); + // Build temporary path segments with unknown path_idx, group_id and leaf value + tmp_paths.push_back(gpu_treeshap::PathElement>{ + ~std::size_t(0), + tree.SplitIndex(parent_idx), + -1, + SplitCondition{lower_bound, upper_bound, comparison_op}, + zero_fraction, + std::numeric_limits::quiet_NaN()}); + child_idx = parent_idx; + parent_idx = parent_id[child_idx]; + } + // Root node has feature -1 + comparison_op = tree.ComparisonOp(child_idx); + // Build temporary path segments with unknown path_idx, group_id and leaf value + tmp_paths.push_back(gpu_treeshap::PathElement>{ + ~std::size_t(0), + -1, + -1, + SplitCondition{-inf, inf, comparison_op}, + 1.0, + std::numeric_limits::quiet_NaN()}); + + // Now duplicate tmp_paths N times, where N = num_groups + // Then insert into path_info.paths + auto leaf_vector = tree.LeafVector(i); + if (leaf_vector.size() != static_cast(num_groups)) { + RAFT_FAIL("Expected leaf vector of length %d but got %d instead", + num_groups, static_cast(leaf_vector.size())); + } + for (int group_id = 0; 
group_id < num_groups; ++group_id) { + for (auto& e : tmp_paths) { + e.path_idx = path_idx; + e.v = static_cast(leaf_vector[group_id]); + e.group = group_id; + } + path_info.paths.insert(path_info.paths.end(), tmp_paths.begin(), tmp_paths.end()); + path_idx++; + } + } + } + tree_idx++; +} + template std::unique_ptr extract_path_info_impl( const tl::ModelImpl& model) @@ -248,9 +345,9 @@ std::unique_ptr extract_path_info_impl( if (!std::is_same::value) { RAFT_FAIL("ThresholdType and LeafType must be identical"); } - if (model.task_type != tl::TaskType::kBinaryClfRegr && - model.task_type != tl::TaskType::kMultiClfGrovePerClass) { - RAFT_FAIL("cuML RF / scikit-learn classifiers are not yet supported"); + if (!std::is_same::value && + !std::is_same::value) { + RAFT_FAIL("ThresholdType must be either float32 or float64"); } std::unique_ptr path_info_ptr = std::make_unique>(); auto* path_info = dynamic_cast*>(path_info_ptr.get()); @@ -258,11 +355,18 @@ std::unique_ptr extract_path_info_impl( std::size_t path_idx = 0; int tree_idx = 0; int num_groups = 1; - if (model.task_type == tl::TaskType::kMultiClfGrovePerClass && model.task_param.num_class > 1) { + if (model.task_param.num_class > 1) { num_groups = model.task_param.num_class; } - for (const tl::Tree& tree : model.trees) { - extract_path_info_from_tree(tree, num_groups, tree_idx, path_idx, *path_info); + if (model.task_type == tl::TaskType::kBinaryClfRegr || + model.task_type == tl::TaskType::kMultiClfGrovePerClass) { + for (const tl::Tree& tree : model.trees) { + extract_path_info_from_tree(tree, num_groups, tree_idx, path_idx, *path_info); + } + } else if (model.task_type == tl::TaskType::kMultiClfProbDistLeaf) { + for (const tl::Tree& tree : model.trees) { + extract_path_info_from_tree_with_leaf_vec(tree, num_groups, tree_idx, path_idx, *path_info); + } } path_info->global_bias = model.param.global_bias; path_info->task_type = model.task_type; diff --git a/python/cuml/experimental/explainer/tree_shap.pyx 
b/python/cuml/experimental/explainer/tree_shap.pyx index f42d26585a..05f8afdba5 100644 --- a/python/cuml/experimental/explainer/tree_shap.pyx +++ b/python/cuml/experimental/explainer/tree_shap.pyx @@ -105,16 +105,13 @@ class TreeExplainer: model = treelite.Model.from_xgboost(model) handle = model.handle.value # cuML RF model object - elif isinstance(model, curfr): + elif isinstance(model, (curfr, curfc)): try: model = model.convert_to_treelite_model() except NotFittedError as e: raise NotFittedError( 'Cannot compute SHAP for un-fitted model') from e handle = model.handle - elif isinstance(model, curfc): - raise NotImplementedError( - 'cuML RF classifiers are not supported yet') elif isinstance(model, treelite.Model): handle = model.handle.value elif isinstance(model, TreeliteModel): diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py b/python/cuml/test/explainer/test_gpu_treeshap.py index fcaa19d6ca..f20674b054 100644 --- a/python/cuml/test/explainer/test_gpu_treeshap.py +++ b/python/cuml/test/explainer/test_gpu_treeshap.py @@ -173,8 +173,11 @@ def test_cuml_rf_classifier(n_classes): n_streams=1, n_estimators=10, max_leaves=-1, max_depth=16, accuracy_metric="mse") cuml_model.fit(X, y) + pred = np.transpose(cuml_model.predict_proba(X), (1, 0)) - with pytest.raises(RuntimeError): - # cuML RF classifier is not supported yet - explainer = TreeExplainer(model=cuml_model) - explainer.shap_values(X) + explainer = TreeExplainer(model=cuml_model) + out = explainer.shap_values(X) + print(out.shape) + # SHAP values should add up to predicted score + shap_sum = np.sum(out, axis=2) + np.tile(explainer.expected_value.reshape(-1, 1), (1, n_samples)) + np.testing.assert_almost_equal(shap_sum, pred, decimal=4) From 3cde92a07283a03a398e9edd5b9c0d9a6084fbc1 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 14 Dec 2021 02:11:55 +0000 Subject: [PATCH 04/38] Fix style --- cpp/src/explainer/tree_shap.cu | 24 ++++++++----------- .../cuml/test/explainer/test_gpu_treeshap.py | 3 
++- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/cpp/src/explainer/tree_shap.cu b/cpp/src/explainer/tree_shap.cu index 16cea4c55c..4e0d159ff1 100644 --- a/cpp/src/explainer/tree_shap.cu +++ b/cpp/src/explainer/tree_shap.cu @@ -17,16 +17,16 @@ #include #include #include +#include #include #include #include #include +#include #include #include #include #include -#include -#include namespace tl = treelite; @@ -249,9 +249,7 @@ void extract_path_info_from_tree_with_leaf_vec(const tl::Tree& path_info) { - if (num_groups < 1) { - RAFT_FAIL("num_groups must be at least 1"); - } + if (num_groups < 1) { RAFT_FAIL("num_groups must be at least 1"); } std::vector parent_id(tree.num_nodes, -1); // Compute parent ID of each node @@ -319,16 +317,17 @@ void extract_path_info_from_tree_with_leaf_vec(const tl::Tree(num_groups)) { RAFT_FAIL("Expected leaf vector of length %d but got %d instead", - num_groups, static_cast(leaf_vector.size())); + num_groups, + static_cast(leaf_vector.size())); } for (int group_id = 0; group_id < num_groups; ++group_id) { for (auto& e : tmp_paths) { e.path_idx = path_idx; - e.v = static_cast(leaf_vector[group_id]); - e.group = group_id; + e.v = static_cast(leaf_vector[group_id]); + e.group = group_id; } path_info.paths.insert(path_info.paths.end(), tmp_paths.begin(), tmp_paths.end()); path_idx++; @@ -345,8 +344,7 @@ std::unique_ptr extract_path_info_impl( if (!std::is_same::value) { RAFT_FAIL("ThresholdType and LeafType must be identical"); } - if (!std::is_same::value && - !std::is_same::value) { + if (!std::is_same::value && !std::is_same::value) { RAFT_FAIL("ThresholdType must be either float32 or float64"); } std::unique_ptr path_info_ptr = std::make_unique>(); @@ -355,9 +353,7 @@ std::unique_ptr extract_path_info_impl( std::size_t path_idx = 0; int tree_idx = 0; int num_groups = 1; - if (model.task_param.num_class > 1) { - num_groups = model.task_param.num_class; - } + if (model.task_param.num_class > 1) { num_groups = 
model.task_param.num_class; } if (model.task_type == tl::TaskType::kBinaryClfRegr || model.task_type == tl::TaskType::kMultiClfGrovePerClass) { for (const tl::Tree& tree : model.trees) { diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py b/python/cuml/test/explainer/test_gpu_treeshap.py index f20674b054..3ae0566528 100644 --- a/python/cuml/test/explainer/test_gpu_treeshap.py +++ b/python/cuml/test/explainer/test_gpu_treeshap.py @@ -179,5 +179,6 @@ def test_cuml_rf_classifier(n_classes): out = explainer.shap_values(X) print(out.shape) # SHAP values should add up to predicted score - shap_sum = np.sum(out, axis=2) + np.tile(explainer.expected_value.reshape(-1, 1), (1, n_samples)) + expected_value = explainer.expected_value.reshape(-1, 1) + shap_sum = np.sum(out, axis=2) + np.tile(expected_value, (1, n_samples)) np.testing.assert_almost_equal(shap_sum, pred, decimal=4) From 7b6610519d686d3603e73cf873b589da6eb21279 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 14 Dec 2021 02:24:45 +0000 Subject: [PATCH 05/38] Remove print() --- python/cuml/test/explainer/test_gpu_treeshap.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py b/python/cuml/test/explainer/test_gpu_treeshap.py index 3ae0566528..c580aca319 100644 --- a/python/cuml/test/explainer/test_gpu_treeshap.py +++ b/python/cuml/test/explainer/test_gpu_treeshap.py @@ -177,7 +177,6 @@ def test_cuml_rf_classifier(n_classes): explainer = TreeExplainer(model=cuml_model) out = explainer.shap_values(X) - print(out.shape) # SHAP values should add up to predicted score expected_value = explainer.expected_value.reshape(-1, 1) shap_sum = np.sum(out, axis=2) + np.tile(expected_value, (1, n_samples)) From 87ebc939906c2021a5dc761b558236be86c77757 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 14 Dec 2021 02:32:29 +0000 Subject: [PATCH 06/38] Test multiple input types in test_cuml_rf_classifier --- .../cuml/test/explainer/test_gpu_treeshap.py | 34 
+++++++++++++++---- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py b/python/cuml/test/explainer/test_gpu_treeshap.py index c580aca319..4b731d5361 100644 --- a/python/cuml/test/explainer/test_gpu_treeshap.py +++ b/python/cuml/test/explainer/test_gpu_treeshap.py @@ -150,34 +150,54 @@ def test_cuml_rf_regressor(input_type): explainer = TreeExplainer(model=cuml_model) out = explainer.shap_values(X) - # SHAP values should add up to predicted score - shap_sum = np.sum(out, axis=1) + explainer.expected_value if input_type == 'cupy': pred = pred.get() - shap_sum = shap_sum.get() + out = out.get() + expected_value = explainer.expected_value.get() elif input_type == 'cudf': pred = pred.to_numpy() - shap_sum = shap_sum.get() + out = out.get() + expected_value = explainer.expected_value.get() + else: + expected_value = explainer.expected_value + # SHAP values should add up to predicted score + shap_sum = np.sum(out, axis=1) + expected_value np.testing.assert_almost_equal(shap_sum, pred, decimal=4) +@pytest.mark.parametrize('input_type', ['numpy', 'cupy', 'cudf']) @pytest.mark.parametrize('n_classes', [2, 5]) -def test_cuml_rf_classifier(n_classes): +def test_cuml_rf_classifier(n_classes, input_type): n_samples = 100 X, y = make_classification(n_samples=n_samples, n_features=8, n_informative=8, n_redundant=0, n_repeated=0, n_classes=n_classes, random_state=2021) X, y = X.astype(np.float32), y.astype(np.float32) + if input_type == 'cupy': + X, y = cp.array(X), cp.array(y) + elif input_type == 'cudf': + X, y = cudf.DataFrame(X), cudf.Series(y) cuml_model = curfc(max_features=1.0, max_samples=0.1, n_bins=128, min_samples_leaf=2, random_state=123, n_streams=1, n_estimators=10, max_leaves=-1, max_depth=16, accuracy_metric="mse") cuml_model.fit(X, y) - pred = np.transpose(cuml_model.predict_proba(X), (1, 0)) + pred = cuml_model.predict_proba(X) explainer = TreeExplainer(model=cuml_model) out = explainer.shap_values(X) + 
if input_type == 'cupy': + pred = pred.get() + out = out.get() + expected_value = explainer.expected_value.get() + elif input_type == 'cudf': + pred = pred.to_numpy() + out = out.get() + expected_value = explainer.expected_value.get() + else: + expected_value = explainer.expected_value # SHAP values should add up to predicted score - expected_value = explainer.expected_value.reshape(-1, 1) + expected_value = expected_value.reshape(-1, 1) shap_sum = np.sum(out, axis=2) + np.tile(expected_value, (1, n_samples)) + pred = np.transpose(pred, (1, 0)) np.testing.assert_almost_equal(shap_sum, pred, decimal=4) From 21a2bb9d6e334596d95590442fef73cd8820adfb Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 14 Dec 2021 03:20:15 +0000 Subject: [PATCH 07/38] Test scikit-learn RF regressors and classifiers --- .../cuml/test/explainer/test_gpu_treeshap.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py b/python/cuml/test/explainer/test_gpu_treeshap.py index 4b731d5361..bafe4cedad 100644 --- a/python/cuml/test/explainer/test_gpu_treeshap.py +++ b/python/cuml/test/explainer/test_gpu_treeshap.py @@ -25,6 +25,7 @@ from cuml.ensemble import RandomForestRegressor as curfr from cuml.ensemble import RandomForestClassifier as curfc from sklearn.datasets import make_regression, make_classification +from sklearn.ensemble import RandomForestRegressor as sklrfr if has_xgboost(): import xgboost as xgb @@ -201,3 +202,46 @@ def test_cuml_rf_classifier(n_classes, input_type): shap_sum = np.sum(out, axis=2) + np.tile(expected_value, (1, n_samples)) pred = np.transpose(pred, (1, 0)) np.testing.assert_almost_equal(shap_sum, pred, decimal=4) + +def test_sklearn_regressor(): + n_samples = 100 + X, y = make_regression(n_samples=n_samples, n_features=8, n_informative=8, + n_targets=1, random_state=2021) + X, y = X.astype(np.float32), y.astype(np.float32) + skl_model = sklrfr(max_features=1.0, max_samples=0.1, + min_samples_leaf=2, 
random_state=123, + n_estimators=10, max_depth=16) + skl_model.fit(X, y) + pred = skl_model.predict(X) + + explainer = TreeExplainer(model=skl_model) + out = explainer.shap_values(X) + # SHAP values should add up to predicted score + shap_sum = np.sum(out, axis=1) + explainer.expected_value + np.testing.assert_almost_equal(shap_sum, pred, decimal=4) + +@pytest.mark.parametrize('n_classes', [2, 3, 5]) +@pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn") +def test_sklearn_rf_classifier(n_classes): + n_samples = 100 + X, y = make_classification(n_samples=n_samples, n_features=8, + n_informative=8, n_redundant=0, n_repeated=0, + n_classes=n_classes, random_state=2021) + X, y = X.astype(np.float32), y.astype(np.float32) + skl_model = sklrfc(max_features=1.0, max_samples=0.1, + min_samples_leaf=2, random_state=123, + n_estimators=10, max_depth=16) + skl_model.fit(X, y) + pred = skl_model.predict_proba(X) + + explainer = TreeExplainer(model=skl_model) + out = explainer.shap_values(X) + # SHAP values should add up to predicted score + if n_classes > 2: + expected_value = explainer.expected_value.reshape(-1, 1) + shap_sum = np.sum(out, axis=2) + np.tile(expected_value, (1, n_samples)) + pred = np.transpose(pred, (1, 0)) + else: + shap_sum = np.sum(out, axis=1) + explainer.expected_value + pred = pred[:, 1] + np.testing.assert_almost_equal(shap_sum, pred, decimal=4) From ade0448f55dc347657b8e34895244f30552a26b8 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 14 Dec 2021 03:20:48 +0000 Subject: [PATCH 08/38] Make scikit-learn optional --- python/cuml/experimental/explainer/tree_shap.pyx | 14 ++++++++++++++ python/cuml/test/explainer/test_gpu_treeshap.py | 13 ++++++++++--- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/python/cuml/experimental/explainer/tree_shap.pyx b/python/cuml/experimental/explainer/tree_shap.pyx index 05f8afdba5..e24d14b4ef 100644 --- a/python/cuml/experimental/explainer/tree_shap.pyx +++ 
b/python/cuml/experimental/explainer/tree_shap.pyx @@ -16,6 +16,7 @@ from cuml.common import input_to_cuml_array from cuml.common.array import CumlArray +from cuml.common.import_utils import has_sklearn from cuml.common.input_utils import determine_array_type from cuml.common.exceptions import NotFittedError from cuml.fil.fil import TreeliteModel @@ -28,6 +29,15 @@ from libcpp.utility cimport move import numpy as np import treelite +if has_sklearn(): + from sklearn.ensemble import RandomForestRegressor as sklrfr + from sklearn.ensemble import RandomForestClassifier as sklrfc +else: + class PlaceHolder: + pass + sklrfr = PlaceHolder + sklrfc = PlaceHolder + cdef extern from "treelite/c_api.h": ctypedef void* ModelHandle cdef int TreeliteQueryNumClass(ModelHandle handle, size_t* out) @@ -112,6 +122,10 @@ class TreeExplainer: raise NotFittedError( 'Cannot compute SHAP for un-fitted model') from e handle = model.handle + # scikit-learn RF model object + elif isinstance(model, (sklrfr, sklrfc)): + model = treelite.sklearn.import_model(model) + handle = model.handle.value elif isinstance(model, treelite.Model): handle = model.handle.value elif isinstance(model, TreeliteModel): diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py b/python/cuml/test/explainer/test_gpu_treeshap.py index bafe4cedad..70e114a1f3 100644 --- a/python/cuml/test/explainer/test_gpu_treeshap.py +++ b/python/cuml/test/explainer/test_gpu_treeshap.py @@ -20,17 +20,19 @@ import cupy as cp import cudf from cuml.experimental.explainer.tree_shap import TreeExplainer -from cuml.common.import_utils import has_xgboost, has_shap +from cuml.common.import_utils import has_xgboost, has_shap, has_sklearn from cuml.common.exceptions import NotFittedError from cuml.ensemble import RandomForestRegressor as curfr from cuml.ensemble import RandomForestClassifier as curfc -from sklearn.datasets import make_regression, make_classification -from sklearn.ensemble import RandomForestRegressor as sklrfr if 
has_xgboost(): import xgboost as xgb if has_shap(): import shap +if has_sklearn(): + from sklearn.datasets import make_regression, make_classification + from sklearn.ensemble import RandomForestRegressor as sklrfr + from sklearn.ensemble import RandomForestClassifier as sklrfc @pytest.mark.parametrize('objective', ['reg:linear', 'reg:squarederror', @@ -38,6 +40,7 @@ 'reg:pseudohubererror']) @pytest.mark.skipif(not has_xgboost(), reason="need to install xgboost") @pytest.mark.skipif(not has_shap(), reason="need to install shap") +@pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn") def test_xgb_regressor(objective): n_samples = 100 X, y = make_regression(n_samples=n_samples, n_features=8, n_informative=8, @@ -81,6 +84,7 @@ def test_xgb_regressor(objective): 'multi:softmax', 'multi:softprob']) @pytest.mark.skipif(not has_xgboost(), reason="need to install xgboost") @pytest.mark.skipif(not has_shap(), reason="need to install shap") +@pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn") def test_xgb_classifier(objective, n_classes): n_samples = 100 X, y = make_classification(n_samples=n_samples, n_features=8, @@ -133,6 +137,7 @@ def test_degenerate_cases(): @pytest.mark.parametrize('input_type', ['numpy', 'cupy', 'cudf']) +@pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn") def test_cuml_rf_regressor(input_type): n_samples = 100 X, y = make_regression(n_samples=n_samples, n_features=8, n_informative=8, @@ -168,6 +173,7 @@ def test_cuml_rf_regressor(input_type): @pytest.mark.parametrize('input_type', ['numpy', 'cupy', 'cudf']) @pytest.mark.parametrize('n_classes', [2, 5]) +@pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn") def test_cuml_rf_classifier(n_classes, input_type): n_samples = 100 X, y = make_classification(n_samples=n_samples, n_features=8, @@ -203,6 +209,7 @@ def test_cuml_rf_classifier(n_classes, input_type): pred = np.transpose(pred, (1, 0)) 
np.testing.assert_almost_equal(shap_sum, pred, decimal=4) +@pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn") def test_sklearn_regressor(): n_samples = 100 X, y = make_regression(n_samples=n_samples, n_features=8, n_informative=8, From e3667c4f6cde3d55f90df2f537cb0aabce1d812a Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 14 Dec 2021 03:22:19 +0000 Subject: [PATCH 09/38] Fix style --- python/cuml/test/explainer/test_gpu_treeshap.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py b/python/cuml/test/explainer/test_gpu_treeshap.py index 70e114a1f3..223d7c1075 100644 --- a/python/cuml/test/explainer/test_gpu_treeshap.py +++ b/python/cuml/test/explainer/test_gpu_treeshap.py @@ -209,6 +209,7 @@ def test_cuml_rf_classifier(n_classes, input_type): pred = np.transpose(pred, (1, 0)) np.testing.assert_almost_equal(shap_sum, pred, decimal=4) + @pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn") def test_sklearn_regressor(): n_samples = 100 @@ -227,6 +228,7 @@ def test_sklearn_regressor(): shap_sum = np.sum(out, axis=1) + explainer.expected_value np.testing.assert_almost_equal(shap_sum, pred, decimal=4) + @pytest.mark.parametrize('n_classes', [2, 3, 5]) @pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn") def test_sklearn_rf_classifier(n_classes): @@ -246,7 +248,8 @@ def test_sklearn_rf_classifier(n_classes): # SHAP values should add up to predicted score if n_classes > 2: expected_value = explainer.expected_value.reshape(-1, 1) - shap_sum = np.sum(out, axis=2) + np.tile(expected_value, (1, n_samples)) + shap_sum = np.sum(out, axis=2) + np.tile(expected_value, + (1, n_samples)) pred = np.transpose(pred, (1, 0)) else: shap_sum = np.sum(out, axis=1) + explainer.expected_value From 87458a63487ba5c1cd5d8582091e942ce35f7c30 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 17 Dec 2021 23:27:52 +0000 Subject: [PATCH 10/38] 
Consolidate path extraction logic --- cpp/src/explainer/tree_shap.cu | 116 +++++++++------------------------ 1 file changed, 30 insertions(+), 86 deletions(-) diff --git a/cpp/src/explainer/tree_shap.cu b/cpp/src/explainer/tree_shap.cu index 4e0d159ff1..4720c9cdc2 100644 --- a/cpp/src/explainer/tree_shap.cu +++ b/cpp/src/explainer/tree_shap.cu @@ -172,82 +172,12 @@ void gpu_treeshap_impl(const TreePathInfoImpl* path_info, namespace ML { namespace Explainer { -template +template void extract_path_info_from_tree(const tl::Tree& tree, int num_groups, int& tree_idx, std::size_t& path_idx, TreePathInfoImpl& path_info) -{ - std::vector parent_id(tree.num_nodes, -1); - // Compute parent ID of each node - for (int i = 0; i < tree.num_nodes; i++) { - if (!tree.IsLeaf(i)) { - parent_id[tree.LeftChild(i)] = i; - parent_id[tree.RightChild(i)] = i; - } - } - - // Find leaf nodes - // Work backwards from leaf to root, order does not matter - // It's also possible to work from root to leaf - int group_id = tree_idx % num_groups; - for (int i = 0; i < tree.num_nodes; i++) { - if (tree.IsLeaf(i)) { - auto v = static_cast(tree.LeafValue(i)); - int child_idx = i; - int parent_idx = parent_id[child_idx]; - constexpr auto inf = std::numeric_limits::infinity(); - tl::Operator comparison_op = tl::Operator::kNone; - while (parent_idx != -1) { - double zero_fraction = 1.0; - bool has_count_info = false; - if (tree.HasSumHess(parent_idx) && tree.HasSumHess(child_idx)) { - zero_fraction = static_cast(tree.SumHess(child_idx) / tree.SumHess(parent_idx)); - has_count_info = true; - } - if (tree.HasDataCount(parent_idx) && tree.HasDataCount(child_idx)) { - zero_fraction = - static_cast(tree.DataCount(child_idx)) / tree.DataCount(parent_idx); - has_count_info = true; - } - if (!has_count_info) { RAFT_FAIL("Tree model doesn't have data count information"); } - // Encode the range of feature values that flow down this path - bool is_left_path = tree.LeftChild(parent_idx) == child_idx; - if 
(tree.SplitType(parent_idx) == tl::SplitFeatureType::kCategorical) { - RAFT_FAIL( - "Only trees with numerical splits are supported. " - "Trees with categorical splits are not supported yet."); - } - ThresholdType lower_bound = is_left_path ? -inf : tree.Threshold(parent_idx); - ThresholdType upper_bound = is_left_path ? tree.Threshold(parent_idx) : inf; - comparison_op = tree.ComparisonOp(parent_idx); - path_info.paths.push_back(gpu_treeshap::PathElement>{ - path_idx, - tree.SplitIndex(parent_idx), - group_id, - SplitCondition{lower_bound, upper_bound, comparison_op}, - zero_fraction, - v}); - child_idx = parent_idx; - parent_idx = parent_id[child_idx]; - } - // Root node has feature -1 - comparison_op = tree.ComparisonOp(child_idx); - path_info.paths.push_back(gpu_treeshap::PathElement>{ - path_idx, -1, group_id, SplitCondition{-inf, inf, comparison_op}, 1.0, v}); - path_idx++; - } - } - tree_idx++; -} - -template -void extract_path_info_from_tree_with_leaf_vec(const tl::Tree& tree, - int num_groups, - int& tree_idx, - std::size_t& path_idx, - TreePathInfoImpl& path_info) { if (num_groups < 1) { RAFT_FAIL("num_groups must be at least 1"); } @@ -263,10 +193,10 @@ void extract_path_info_from_tree_with_leaf_vec(const tl::Tree>> tmp_paths; - int child_idx = i; + int child_idx = nid; int parent_idx = parent_id[child_idx]; constexpr auto inf = std::numeric_limits::infinity(); tl::Operator comparison_op = tl::Operator::kNone; @@ -315,21 +245,35 @@ void extract_path_info_from_tree_with_leaf_vec(const tl::Tree::quiet_NaN()}); - // Now duplicate tmp_paths N times, where N = num_groups - // Then insert into path_info.paths - auto leaf_vector = tree.LeafVector(i); - if (leaf_vector.size() != static_cast(num_groups)) { - RAFT_FAIL("Expected leaf vector of length %d but got %d instead", - num_groups, - static_cast(leaf_vector.size())); - } - for (int group_id = 0; group_id < num_groups; ++group_id) { + // If use_vector_leaf=True: + // * Duplicate tmp_paths N times, where N = 
num_groups + // * Insert into path_info.paths + // If use_vector_leaf=False: + // * Insert tmp_paths into path_info.paths + auto path_insertor = [&tmp_paths, &path_info]( + auto leaf_value, auto path_idx, int group_id) { for (auto& e : tmp_paths) { e.path_idx = path_idx; - e.v = static_cast(leaf_vector[group_id]); + e.v = static_cast(leaf_value); e.group = group_id; } path_info.paths.insert(path_info.paths.end(), tmp_paths.begin(), tmp_paths.end()); + }; + if constexpr (use_vector_leaf) { + auto leaf_vector = tree.LeafVector(nid); + if (leaf_vector.size() != static_cast(num_groups)) { + RAFT_FAIL("Expected leaf vector of length %d but got %d instead", + num_groups, + static_cast(leaf_vector.size())); + } + for (int group_id = 0; group_id < num_groups; ++group_id) { + path_insertor(leaf_vector[group_id], path_idx, group_id); + path_idx++; + } + } else { + auto leaf_value = tree.LeafValue(nid); + int group_id = tree_idx % num_groups; + path_insertor(leaf_value, path_idx, group_id); path_idx++; } } @@ -357,11 +301,11 @@ std::unique_ptr extract_path_info_impl( if (model.task_type == tl::TaskType::kBinaryClfRegr || model.task_type == tl::TaskType::kMultiClfGrovePerClass) { for (const tl::Tree& tree : model.trees) { - extract_path_info_from_tree(tree, num_groups, tree_idx, path_idx, *path_info); + extract_path_info_from_tree(tree, num_groups, tree_idx, path_idx, *path_info); } } else if (model.task_type == tl::TaskType::kMultiClfProbDistLeaf) { for (const tl::Tree& tree : model.trees) { - extract_path_info_from_tree_with_leaf_vec(tree, num_groups, tree_idx, path_idx, *path_info); + extract_path_info_from_tree(tree, num_groups, tree_idx, path_idx, *path_info); } } path_info->global_bias = model.param.global_bias; From d0dcefd82989499d06fb5f54097eb349503dbc43 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 17 Dec 2021 23:59:28 +0000 Subject: [PATCH 11/38] Use shap.explainers.Tree --- python/cuml/test/explainer/test_gpu_treeshap.py | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-) diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py b/python/cuml/test/explainer/test_gpu_treeshap.py index 223d7c1075..034ff2be8f 100644 --- a/python/cuml/test/explainer/test_gpu_treeshap.py +++ b/python/cuml/test/explainer/test_gpu_treeshap.py @@ -61,7 +61,7 @@ def test_xgb_regressor(objective): explainer = TreeExplainer(model=tl_model) out = explainer.shap_values(X) - ref_explainer = shap.TreeExplainer(model=xgb_model) + ref_explainer = shap.explainers.Tree(model=xgb_model) correct_out = ref_explainer.shap_values(X) np.testing.assert_almost_equal(out, correct_out) np.testing.assert_almost_equal(explainer.expected_value, @@ -105,7 +105,7 @@ def test_xgb_classifier(objective, n_classes): explainer = TreeExplainer(model=xgb_model) out = explainer.shap_values(X) - ref_explainer = shap.TreeExplainer(model=xgb_model) + ref_explainer = shap.explainers.Tree(model=xgb_model) correct_out = ref_explainer.shap_values(X) np.testing.assert_almost_equal(out, correct_out) np.testing.assert_almost_equal(explainer.expected_value, From 729e98d2304cf7da6cbd04f5f18201a01bc05350 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Sat, 18 Dec 2021 00:00:31 +0000 Subject: [PATCH 12/38] Fix style --- cpp/src/explainer/tree_shap.cu | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cpp/src/explainer/tree_shap.cu b/cpp/src/explainer/tree_shap.cu index 4720c9cdc2..8d4dadf7d9 100644 --- a/cpp/src/explainer/tree_shap.cu +++ b/cpp/src/explainer/tree_shap.cu @@ -250,8 +250,7 @@ void extract_path_info_from_tree(const tl::Tree& tree, // * Insert into path_info.paths // If use_vector_leaf=False: // * Insert tmp_paths into path_info.paths - auto path_insertor = [&tmp_paths, &path_info]( - auto leaf_value, auto path_idx, int group_id) { + auto path_insertor = [&tmp_paths, &path_info](auto leaf_value, auto path_idx, int group_id) { for (auto& e : tmp_paths) { e.path_idx = path_idx; e.v = static_cast(leaf_value); @@ -272,7 +271,7 @@ void 
extract_path_info_from_tree(const tl::Tree& tree, } } else { auto leaf_value = tree.LeafValue(nid); - int group_id = tree_idx % num_groups; + int group_id = tree_idx % num_groups; path_insertor(leaf_value, path_idx, group_id); path_idx++; } From 80e45a511a03d9e24a5b2818e34189b0091ae2bb Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Mon, 20 Dec 2021 12:20:56 +0000 Subject: [PATCH 13/38] Use weighted sample count in sklearn models --- cpp/src/explainer/tree_shap.cu | 5 +-- .../cuml/test/explainer/test_gpu_treeshap.py | 31 ++++++++++--------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/cpp/src/explainer/tree_shap.cu b/cpp/src/explainer/tree_shap.cu index 8d4dadf7d9..70aac973e8 100644 --- a/cpp/src/explainer/tree_shap.cu +++ b/cpp/src/explainer/tree_shap.cu @@ -207,7 +207,7 @@ void extract_path_info_from_tree(const tl::Tree& tree, zero_fraction = static_cast(tree.SumHess(child_idx) / tree.SumHess(parent_idx)); has_count_info = true; } - if (tree.HasDataCount(parent_idx) && tree.HasDataCount(child_idx)) { + if (!has_count_info && tree.HasDataCount(parent_idx) && tree.HasDataCount(child_idx)) { zero_fraction = static_cast(tree.DataCount(child_idx)) / tree.DataCount(parent_idx); has_count_info = true; @@ -256,7 +256,7 @@ void extract_path_info_from_tree(const tl::Tree& tree, e.v = static_cast(leaf_value); e.group = group_id; } - path_info.paths.insert(path_info.paths.end(), tmp_paths.begin(), tmp_paths.end()); + path_info.paths.insert(path_info.paths.end(), tmp_paths.cbegin(), tmp_paths.cend()); }; if constexpr (use_vector_leaf) { auto leaf_vector = tree.LeafVector(nid); @@ -290,6 +290,7 @@ std::unique_ptr extract_path_info_impl( if (!std::is_same::value && !std::is_same::value) { RAFT_FAIL("ThresholdType must be either float32 or float64"); } + std::unique_ptr path_info_ptr = std::make_unique>(); auto* path_info = dynamic_cast*>(path_info_ptr.get()); diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py 
b/python/cuml/test/explainer/test_gpu_treeshap.py index 034ff2be8f..eeb44499ab 100644 --- a/python/cuml/test/explainer/test_gpu_treeshap.py +++ b/python/cuml/test/explainer/test_gpu_treeshap.py @@ -211,7 +211,7 @@ def test_cuml_rf_classifier(n_classes, input_type): @pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn") -def test_sklearn_regressor(): +def test_sklearn_rf_regressor(): n_samples = 100 X, y = make_regression(n_samples=n_samples, n_features=8, n_informative=8, n_targets=1, random_state=2021) @@ -224,9 +224,12 @@ def test_sklearn_regressor(): explainer = TreeExplainer(model=skl_model) out = explainer.shap_values(X) - # SHAP values should add up to predicted score - shap_sum = np.sum(out, axis=1) + explainer.expected_value - np.testing.assert_almost_equal(shap_sum, pred, decimal=4) + + ref_explainer = shap.explainers.Tree(model=skl_model) + correct_out = ref_explainer.shap_values(X) + np.testing.assert_almost_equal(out, correct_out, decimal=5) + np.testing.assert_almost_equal(explainer.expected_value, + ref_explainer.expected_value, decimal=5) @pytest.mark.parametrize('n_classes', [2, 3, 5]) @@ -245,13 +248,13 @@ def test_sklearn_rf_classifier(n_classes): explainer = TreeExplainer(model=skl_model) out = explainer.shap_values(X) - # SHAP values should add up to predicted score - if n_classes > 2: - expected_value = explainer.expected_value.reshape(-1, 1) - shap_sum = np.sum(out, axis=2) + np.tile(expected_value, - (1, n_samples)) - pred = np.transpose(pred, (1, 0)) - else: - shap_sum = np.sum(out, axis=1) + explainer.expected_value - pred = pred[:, 1] - np.testing.assert_almost_equal(shap_sum, pred, decimal=4) + + ref_explainer = shap.explainers.Tree(model=skl_model) + correct_out = np.array(ref_explainer.shap_values(X)) + expected_value = ref_explainer.expected_value + if n_classes == 2: + correct_out = correct_out[1, :, :] + expected_value = expected_value[1:] + np.testing.assert_almost_equal(out, correct_out, decimal=5) + 
np.testing.assert_almost_equal(explainer.expected_value, + expected_value, decimal=5) From 69d646155937259b8a8e3c2a08f6d2fa6950d229 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Mon, 20 Dec 2021 12:38:06 +0000 Subject: [PATCH 14/38] Add missing skipif mark --- python/cuml/test/explainer/test_gpu_treeshap.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py b/python/cuml/test/explainer/test_gpu_treeshap.py index eeb44499ab..b179c9871d 100644 --- a/python/cuml/test/explainer/test_gpu_treeshap.py +++ b/python/cuml/test/explainer/test_gpu_treeshap.py @@ -210,6 +210,7 @@ def test_cuml_rf_classifier(n_classes, input_type): np.testing.assert_almost_equal(shap_sum, pred, decimal=4) +@pytest.mark.skipif(not has_shap(), reason="need to install shap") @pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn") def test_sklearn_rf_regressor(): n_samples = 100 @@ -233,6 +234,7 @@ def test_sklearn_rf_regressor(): @pytest.mark.parametrize('n_classes', [2, 3, 5]) +@pytest.mark.skipif(not has_shap(), reason="need to install shap") @pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn") def test_sklearn_rf_classifier(n_classes): n_samples = 100 From cc54ae12ba9c8f48b09dce937b6050fffcd6c249 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 22 Dec 2021 01:03:55 +0000 Subject: [PATCH 15/38] Extract traverse_towards_leaf_node() --- cpp/src/explainer/tree_shap.cu | 128 +++++++++++++++++---------------- 1 file changed, 68 insertions(+), 60 deletions(-) diff --git a/cpp/src/explainer/tree_shap.cu b/cpp/src/explainer/tree_shap.cu index 70aac973e8..4644d6024d 100644 --- a/cpp/src/explainer/tree_shap.cu +++ b/cpp/src/explainer/tree_shap.cu @@ -172,6 +172,65 @@ void gpu_treeshap_impl(const TreePathInfoImpl* path_info, namespace ML { namespace Explainer { +// Traverse a path from the root node to a leaf node and return the list of the path segments +// Note: the path segments will have missing values 
in path_idx, group_id and v (leaf value). +// The callser is responsible for filling in these fields. +template +std::vector>> +traverse_towards_leaf_node(const tl::Tree& tree, + int leaf_node_id, + const std::vector& parent_id) { + std::vector>> path_segments; + int child_idx = leaf_node_id; + int parent_idx = parent_id[child_idx]; + constexpr auto inf = std::numeric_limits::infinity(); + tl::Operator comparison_op = tl::Operator::kNone; + while (parent_idx != -1) { + double zero_fraction = 1.0; + bool has_count_info = false; + if (tree.HasSumHess(parent_idx) && tree.HasSumHess(child_idx)) { + zero_fraction = static_cast(tree.SumHess(child_idx) / tree.SumHess(parent_idx)); + has_count_info = true; + } + if (!has_count_info && tree.HasDataCount(parent_idx) && tree.HasDataCount(child_idx)) { + zero_fraction = + static_cast(tree.DataCount(child_idx)) / tree.DataCount(parent_idx); + has_count_info = true; + } + if (!has_count_info) { RAFT_FAIL("Tree model doesn't have data count information"); } + // Encode the range of feature values that flow down this path + bool is_left_path = tree.LeftChild(parent_idx) == child_idx; + if (tree.SplitType(parent_idx) == tl::SplitFeatureType::kCategorical) { + RAFT_FAIL( + "Only trees with numerical splits are supported. " + "Trees with categorical splits are not supported yet."); + } + ThresholdType lower_bound = is_left_path ? -inf : tree.Threshold(parent_idx); + ThresholdType upper_bound = is_left_path ? 
tree.Threshold(parent_idx) : inf; + comparison_op = tree.ComparisonOp(parent_idx); + path_segments.push_back(gpu_treeshap::PathElement>{ + ~std::size_t(0), + tree.SplitIndex(parent_idx), + -1, + SplitCondition{lower_bound, upper_bound, comparison_op}, + zero_fraction, + std::numeric_limits::quiet_NaN()}); + child_idx = parent_idx; + parent_idx = parent_id[child_idx]; + } + // Root node has feature -1 + comparison_op = tree.ComparisonOp(child_idx); + // Build temporary path segments with unknown path_idx, group_id and leaf value + path_segments.push_back(gpu_treeshap::PathElement>{ + ~std::size_t(0), + -1, + -1, + SplitCondition{-inf, inf, comparison_op}, + 1.0, + std::numeric_limits::quiet_NaN()}); + return path_segments; +} + template void extract_path_info_from_tree(const tl::Tree& tree, int num_groups, @@ -181,8 +240,8 @@ void extract_path_info_from_tree(const tl::Tree& tree, { if (num_groups < 1) { RAFT_FAIL("num_groups must be at least 1"); } - std::vector parent_id(tree.num_nodes, -1); // Compute parent ID of each node + std::vector parent_id(tree.num_nodes, -1); for (int i = 0; i < tree.num_nodes; i++) { if (!tree.IsLeaf(i)) { parent_id[tree.LeftChild(i)] = i; @@ -190,73 +249,22 @@ void extract_path_info_from_tree(const tl::Tree& tree, } } - // Find leaf nodes - // Work backwards from leaf to root, order does not matter - // It's also possible to work from root to leaf for (int nid = 0; nid < tree.num_nodes; nid++) { - if (tree.IsLeaf(nid)) { - std::vector>> tmp_paths; - int child_idx = nid; - int parent_idx = parent_id[child_idx]; - constexpr auto inf = std::numeric_limits::infinity(); - tl::Operator comparison_op = tl::Operator::kNone; - while (parent_idx != -1) { - double zero_fraction = 1.0; - bool has_count_info = false; - if (tree.HasSumHess(parent_idx) && tree.HasSumHess(child_idx)) { - zero_fraction = static_cast(tree.SumHess(child_idx) / tree.SumHess(parent_idx)); - has_count_info = true; - } - if (!has_count_info && tree.HasDataCount(parent_idx) && 
tree.HasDataCount(child_idx)) { - zero_fraction = - static_cast(tree.DataCount(child_idx)) / tree.DataCount(parent_idx); - has_count_info = true; - } - if (!has_count_info) { RAFT_FAIL("Tree model doesn't have data count information"); } - // Encode the range of feature values that flow down this path - bool is_left_path = tree.LeftChild(parent_idx) == child_idx; - if (tree.SplitType(parent_idx) == tl::SplitFeatureType::kCategorical) { - RAFT_FAIL( - "Only trees with numerical splits are supported. " - "Trees with categorical splits are not supported yet."); - } - ThresholdType lower_bound = is_left_path ? -inf : tree.Threshold(parent_idx); - ThresholdType upper_bound = is_left_path ? tree.Threshold(parent_idx) : inf; - comparison_op = tree.ComparisonOp(parent_idx); - // Build temporary path segments with unknown path_idx, group_id and leaf value - tmp_paths.push_back(gpu_treeshap::PathElement>{ - ~std::size_t(0), - tree.SplitIndex(parent_idx), - -1, - SplitCondition{lower_bound, upper_bound, comparison_op}, - zero_fraction, - std::numeric_limits::quiet_NaN()}); - child_idx = parent_idx; - parent_idx = parent_id[child_idx]; - } - // Root node has feature -1 - comparison_op = tree.ComparisonOp(child_idx); - // Build temporary path segments with unknown path_idx, group_id and leaf value - tmp_paths.push_back(gpu_treeshap::PathElement>{ - ~std::size_t(0), - -1, - -1, - SplitCondition{-inf, inf, comparison_op}, - 1.0, - std::numeric_limits::quiet_NaN()}); - + if (tree.IsLeaf(nid)) { // For each leaf node... 
+ // Extract path segments by traversing the path from the leaf node to the root node + auto path_segments = traverse_towards_leaf_node(tree, nid, parent_id); // If use_vector_leaf=True: - // * Duplicate tmp_paths N times, where N = num_groups + // * Duplicate the path segments N times, where N = num_groups // * Insert into path_info.paths // If use_vector_leaf=False: - // * Insert tmp_paths into path_info.paths - auto path_insertor = [&tmp_paths, &path_info](auto leaf_value, auto path_idx, int group_id) { - for (auto& e : tmp_paths) { + // * Insert the path segments into path_info.paths + auto path_insertor = [&path_segments, &path_info](auto leaf_value, auto path_idx, int group_id) { + for (auto& e : path_segments) { e.path_idx = path_idx; e.v = static_cast(leaf_value); e.group = group_id; } - path_info.paths.insert(path_info.paths.end(), tmp_paths.cbegin(), tmp_paths.cend()); + path_info.paths.insert(path_info.paths.end(), path_segments.cbegin(), path_segments.cend()); }; if constexpr (use_vector_leaf) { auto leaf_vector = tree.LeafVector(nid); From 4da8485541bf4f86ab870a9f9032b6b57a5dc88e Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 22 Dec 2021 01:49:43 +0000 Subject: [PATCH 16/38] Eliminate the use of reference parameter --- cpp/src/explainer/tree_shap.cu | 64 +++++++++++++++++++++------------- 1 file changed, 40 insertions(+), 24 deletions(-) diff --git a/cpp/src/explainer/tree_shap.cu b/cpp/src/explainer/tree_shap.cu index 4644d6024d..2594961c18 100644 --- a/cpp/src/explainer/tree_shap.cu +++ b/cpp/src/explainer/tree_shap.cu @@ -179,7 +179,8 @@ template std::vector>> traverse_towards_leaf_node(const tl::Tree& tree, int leaf_node_id, - const std::vector& parent_id) { + const std::vector& parent_id) +{ std::vector>> path_segments; int child_idx = leaf_node_id; int parent_idx = parent_id[child_idx]; @@ -231,15 +232,21 @@ traverse_towards_leaf_node(const tl::Tree& tree, return path_segments; } -template -void extract_path_info_from_tree(const tl::Tree& 
tree, - int num_groups, - int& tree_idx, - std::size_t& path_idx, - TreePathInfoImpl& path_info) +// Extract the path segments from a single tree. Each path segment will have path_idx field, which +// uniquely identifies the path to which the segment belongs. The path_idx_offset parameter sets +// the path_idx field of the first path segment. +template +std::vector>> +extract_path_segments_from_tree(const std::vector>& tree_list, + std::size_t tree_idx, + bool use_vector_leaf, + int num_groups, + std::size_t path_idx_offset) { if (num_groups < 1) { RAFT_FAIL("num_groups must be at least 1"); } + const tl::Tree& tree = tree_list[tree_idx]; + // Compute parent ID of each node std::vector parent_id(tree.num_nodes, -1); for (int i = 0; i < tree.num_nodes; i++) { @@ -249,24 +256,27 @@ void extract_path_info_from_tree(const tl::Tree& tree, } } + std::size_t path_idx = path_idx_offset; + std::vector>> path_segments; + for (int nid = 0; nid < tree.num_nodes; nid++) { if (tree.IsLeaf(nid)) { // For each leaf node... 
// Extract path segments by traversing the path from the leaf node to the root node - auto path_segments = traverse_towards_leaf_node(tree, nid, parent_id); + auto path_to_leaf = traverse_towards_leaf_node(tree, nid, parent_id); // If use_vector_leaf=True: // * Duplicate the path segments N times, where N = num_groups - // * Insert into path_info.paths + // * Insert the duplicated path segments into path_segments // If use_vector_leaf=False: - // * Insert the path segments into path_info.paths - auto path_insertor = [&path_segments, &path_info](auto leaf_value, auto path_idx, int group_id) { - for (auto& e : path_segments) { + // * Insert the path segments into path_segments + auto path_insertor = [&path_to_leaf, &path_segments](auto leaf_value, auto path_idx, int group_id) { + for (auto& e : path_to_leaf) { e.path_idx = path_idx; e.v = static_cast(leaf_value); e.group = group_id; } - path_info.paths.insert(path_info.paths.end(), path_segments.cbegin(), path_segments.cend()); + path_segments.insert(path_segments.end(), path_to_leaf.cbegin(), path_to_leaf.cend()); }; - if constexpr (use_vector_leaf) { + if (use_vector_leaf) { auto leaf_vector = tree.LeafVector(nid); if (leaf_vector.size() != static_cast(num_groups)) { RAFT_FAIL("Expected leaf vector of length %d but got %d instead", @@ -279,13 +289,13 @@ void extract_path_info_from_tree(const tl::Tree& tree, } } else { auto leaf_value = tree.LeafValue(nid); - int group_id = tree_idx % num_groups; + int group_id = static_cast(tree_idx) % num_groups; path_insertor(leaf_value, path_idx, group_id); path_idx++; } } } - tree_idx++; + return path_segments; } template @@ -302,18 +312,24 @@ std::unique_ptr extract_path_info_impl( std::unique_ptr path_info_ptr = std::make_unique>(); auto* path_info = dynamic_cast*>(path_info_ptr.get()); - std::size_t path_idx = 0; - int tree_idx = 0; - int num_groups = 1; + int num_groups = 1; + bool use_vector_leaf; if (model.task_param.num_class > 1) { num_groups = 
model.task_param.num_class; } if (model.task_type == tl::TaskType::kBinaryClfRegr || model.task_type == tl::TaskType::kMultiClfGrovePerClass) { - for (const tl::Tree& tree : model.trees) { - extract_path_info_from_tree(tree, num_groups, tree_idx, path_idx, *path_info); - } + use_vector_leaf = false; } else if (model.task_type == tl::TaskType::kMultiClfProbDistLeaf) { - for (const tl::Tree& tree : model.trees) { - extract_path_info_from_tree(tree, num_groups, tree_idx, path_idx, *path_info); + use_vector_leaf = true; + } else { + RAFT_FAIL("Unsupported task_type: %d", static_cast(model.task_type)); + } + std::size_t path_idx = 0; + for (std::size_t tree_idx = 0; tree_idx < model.trees.size(); ++tree_idx) { + auto path_segments = extract_path_segments_from_tree(model.trees, tree_idx, use_vector_leaf, + num_groups, path_idx); + path_info->paths.insert(path_info->paths.end(), path_segments.cbegin(), path_segments.cend()); + if (!path_segments.empty()) { + path_idx = path_segments.back().path_idx + 1; } } path_info->global_bias = model.param.global_bias; From b37d6385c01a1d7a3b37cc6d5e920e71ed7fc6d0 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 22 Dec 2021 01:51:32 +0000 Subject: [PATCH 17/38] Fix style --- cpp/src/explainer/tree_shap.cu | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/cpp/src/explainer/tree_shap.cu b/cpp/src/explainer/tree_shap.cu index 2594961c18..3f8d271505 100644 --- a/cpp/src/explainer/tree_shap.cu +++ b/cpp/src/explainer/tree_shap.cu @@ -176,10 +176,10 @@ namespace Explainer { // Note: the path segments will have missing values in path_idx, group_id and v (leaf value). // The callser is responsible for filling in these fields. 
template -std::vector>> -traverse_towards_leaf_node(const tl::Tree& tree, - int leaf_node_id, - const std::vector& parent_id) +std::vector>> traverse_towards_leaf_node( + const tl::Tree& tree, + int leaf_node_id, + const std::vector& parent_id) { std::vector>> path_segments; int child_idx = leaf_node_id; @@ -194,8 +194,7 @@ traverse_towards_leaf_node(const tl::Tree& tree, has_count_info = true; } if (!has_count_info && tree.HasDataCount(parent_idx) && tree.HasDataCount(child_idx)) { - zero_fraction = - static_cast(tree.DataCount(child_idx)) / tree.DataCount(parent_idx); + zero_fraction = static_cast(tree.DataCount(child_idx)) / tree.DataCount(parent_idx); has_count_info = true; } if (!has_count_info) { RAFT_FAIL("Tree model doesn't have data count information"); } @@ -260,7 +259,7 @@ extract_path_segments_from_tree(const std::vector>> path_segments; for (int nid = 0; nid < tree.num_nodes; nid++) { - if (tree.IsLeaf(nid)) { // For each leaf node... + if (tree.IsLeaf(nid)) { // For each leaf node... 
// Extract path segments by traversing the path from the leaf node to the root node auto path_to_leaf = traverse_towards_leaf_node(tree, nid, parent_id); // If use_vector_leaf=True: @@ -268,7 +267,8 @@ extract_path_segments_from_tree(const std::vector(leaf_value); @@ -325,12 +325,10 @@ std::unique_ptr extract_path_info_impl( } std::size_t path_idx = 0; for (std::size_t tree_idx = 0; tree_idx < model.trees.size(); ++tree_idx) { - auto path_segments = extract_path_segments_from_tree(model.trees, tree_idx, use_vector_leaf, - num_groups, path_idx); + auto path_segments = + extract_path_segments_from_tree(model.trees, tree_idx, use_vector_leaf, num_groups, path_idx); path_info->paths.insert(path_info->paths.end(), path_segments.cbegin(), path_segments.cend()); - if (!path_segments.empty()) { - path_idx = path_segments.back().path_idx + 1; - } + if (!path_segments.empty()) { path_idx = path_segments.back().path_idx + 1; } } path_info->global_bias = model.param.global_bias; path_info->task_type = model.task_type; From 8092a1d83745bfb4cc0c59453b544259fad40da8 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 4 Jan 2022 22:11:37 +0000 Subject: [PATCH 18/38] Fix style --- python/cuml/test/explainer/test_gpu_treeshap.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py b/python/cuml/test/explainer/test_gpu_treeshap.py index b179c9871d..a26d9ddaed 100644 --- a/python/cuml/test/explainer/test_gpu_treeshap.py +++ b/python/cuml/test/explainer/test_gpu_treeshap.py @@ -221,7 +221,6 @@ def test_sklearn_rf_regressor(): min_samples_leaf=2, random_state=123, n_estimators=10, max_depth=16) skl_model.fit(X, y) - pred = skl_model.predict(X) explainer = TreeExplainer(model=skl_model) out = explainer.shap_values(X) @@ -246,7 +245,6 @@ def test_sklearn_rf_classifier(n_classes): min_samples_leaf=2, random_state=123, n_estimators=10, max_depth=16) skl_model.fit(X, y) - pred = skl_model.predict_proba(X) explainer = 
TreeExplainer(model=skl_model) out = explainer.shap_values(X) From 5fbc641d28023c729ad32d735bb54ef13aba428e Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 4 Jan 2022 23:01:01 +0000 Subject: [PATCH 19/38] Update copyright years --- cpp/src/explainer/tree_shap.cu | 2 +- python/cuml/experimental/explainer/tree_shap.pyx | 2 +- python/cuml/test/explainer/test_gpu_treeshap.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/src/explainer/tree_shap.cu b/cpp/src/explainer/tree_shap.cu index 3f8d271505..25810d3045 100644 --- a/cpp/src/explainer/tree_shap.cu +++ b/cpp/src/explainer/tree_shap.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/python/cuml/experimental/explainer/tree_shap.pyx b/python/cuml/experimental/explainer/tree_shap.pyx index e24d14b4ef..8e0d8b29dd 100644 --- a/python/cuml/experimental/explainer/tree_shap.pyx +++ b/python/cuml/experimental/explainer/tree_shap.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2021, NVIDIA CORPORATION. +# Copyright (c) 2021-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py b/python/cuml/test/explainer/test_gpu_treeshap.py index a26d9ddaed..6ac43d6766 100644 --- a/python/cuml/test/explainer/test_gpu_treeshap.py +++ b/python/cuml/test/explainer/test_gpu_treeshap.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021, NVIDIA CORPORATION. +# Copyright (c) 2021-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
From d2da04e0780aee9a8e9228588c81157aea358a6f Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 13 Jan 2022 17:01:22 +0000 Subject: [PATCH 20/38] Relax test tolerance --- python/cuml/test/explainer/test_gpu_treeshap.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py b/python/cuml/test/explainer/test_gpu_treeshap.py index 6ac43d6766..5d790185f1 100644 --- a/python/cuml/test/explainer/test_gpu_treeshap.py +++ b/python/cuml/test/explainer/test_gpu_treeshap.py @@ -63,7 +63,7 @@ def test_xgb_regressor(objective): ref_explainer = shap.explainers.Tree(model=xgb_model) correct_out = ref_explainer.shap_values(X) - np.testing.assert_almost_equal(out, correct_out) + np.testing.assert_almost_equal(out, correct_out, decimal=5) np.testing.assert_almost_equal(explainer.expected_value, ref_explainer.expected_value) @@ -107,7 +107,7 @@ def test_xgb_classifier(objective, n_classes): ref_explainer = shap.explainers.Tree(model=xgb_model) correct_out = ref_explainer.shap_values(X) - np.testing.assert_almost_equal(out, correct_out) + np.testing.assert_almost_equal(out, correct_out, decimal=5) np.testing.assert_almost_equal(explainer.expected_value, ref_explainer.expected_value) From 8fc907f047d89c774237cc68c4a7da2fbb3ec489 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 13 Jan 2022 20:16:23 +0000 Subject: [PATCH 21/38] Temporarily use Treelite 2.2.0 for testing --- ci/gpu/build.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ci/gpu/build.sh b/ci/gpu/build.sh index 5433540c8f..665df8858d 100755 --- a/ci/gpu/build.sh +++ b/ci/gpu/build.sh @@ -64,6 +64,11 @@ gpuci_mamba_retry install -c conda-forge -c rapidsai -c rapidsai-nightly -c nvid "rapids-doc-env=${MINOR_VERSION}.*" \ "shap>=0.37,<=0.39" +gpuci_mamba_retry remove -c conda-forge -c rapidsai -c rapidsai-nightly -c nvidia \ + --force rapids-build-env rapids-notebook-env +gpuci_mamba_retry install -c conda-forge -c rapidsai -c rapidsai-nightly -c 
nvidia \ + treelite=2.2.0 + # https://docs.rapids.ai/maintainers/depmgmt/ # gpuci_mamba_retry remove --force rapids-build-env rapids-notebook-env # gpuci_mamba_retry install -y "your-pkg=1.0.0" From f15c7a5e26852e0d6913eb4cee3a2c8be5d53892 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 14 Jan 2022 00:58:12 +0000 Subject: [PATCH 22/38] Temporarily use Treelite 2.2.0 for testing --- conda/recipes/cuml/meta.yaml | 4 ++-- conda/recipes/libcuml/meta.yaml | 4 ++-- cpp/cmake/thirdparty/get_treelite.cmake | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/conda/recipes/cuml/meta.yaml b/conda/recipes/cuml/meta.yaml index 3d10993733..c56f31005f 100644 --- a/conda/recipes/cuml/meta.yaml +++ b/conda/recipes/cuml/meta.yaml @@ -30,7 +30,7 @@ requirements: - setuptools - cython>=0.29,<0.30 - cmake>=3.20.1 - - treelite=2.1.0 + - treelite=2.2.0 - cudf {{ minor_version }} - libcuml={{ version }} - libcumlprims {{ minor_version }} @@ -44,7 +44,7 @@ requirements: - libcuml={{ version }} - libcumlprims {{ minor_version }} - cupy>=7.8.0,<10.0.0a0 - - treelite=2.1.0 + - treelite=2.2.0 - nccl>=2.9.9 - ucx-py {{ ucx_py_version }} - ucx-proc=*=gpu diff --git a/conda/recipes/libcuml/meta.yaml b/conda/recipes/libcuml/meta.yaml index db323560fe..e1814d1237 100644 --- a/conda/recipes/libcuml/meta.yaml +++ b/conda/recipes/libcuml/meta.yaml @@ -43,7 +43,7 @@ requirements: - ucx-proc=*=gpu - libcumlprims {{ minor_version }} - lapack - - treelite=2.1.0 + - treelite=2.2.0 - faiss-proc=*=cuda - gtest=1.10.0 - gmock @@ -55,7 +55,7 @@ requirements: - ucx-py {{ ucx_py_version }} - ucx-proc=*=gpu - {{ pin_compatible('cudatoolkit', max_pin='x', min_pin='x') }} - - treelite=2.1.0 + - treelite=2.2.0 - faiss-proc=*=cuda - libfaiss 1.7.0 *_cuda - libcusolver>=11.2.1 diff --git a/cpp/cmake/thirdparty/get_treelite.cmake b/cpp/cmake/thirdparty/get_treelite.cmake index 171706ea20..479e95936e 100644 --- a/cpp/cmake/thirdparty/get_treelite.cmake +++ 
b/cpp/cmake/thirdparty/get_treelite.cmake @@ -54,5 +54,5 @@ function(find_and_configure_treelite) endfunction() -find_and_configure_treelite(VERSION 2.1.0 - PINNED_TAG e5248931c62e3807248e0b150e27b2530a510634) +find_and_configure_treelite(VERSION 2.2.0 + PINNED_TAG 2a62c6f569a6b12520cbfc35b4a727bb6bb671b1) From 589d18a77e21762d1731309fa2894dc09e477576 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 14 Jan 2022 02:01:47 +0000 Subject: [PATCH 23/38] Fix copyright years --- conda/recipes/libcuml/meta.yaml | 2 +- cpp/cmake/thirdparty/get_treelite.cmake | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/conda/recipes/libcuml/meta.yaml b/conda/recipes/libcuml/meta.yaml index e1814d1237..05877c35e9 100644 --- a/conda/recipes/libcuml/meta.yaml +++ b/conda/recipes/libcuml/meta.yaml @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2021, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. # Usage: # conda build . -c conda-forge -c nvidia -c rapidsai -c pytorch diff --git a/cpp/cmake/thirdparty/get_treelite.cmake b/cpp/cmake/thirdparty/get_treelite.cmake index 479e95936e..bcbf65aa6b 100644 --- a/cpp/cmake/thirdparty/get_treelite.cmake +++ b/cpp/cmake/thirdparty/get_treelite.cmake @@ -1,5 +1,5 @@ #============================================================================= -# Copyright (c) 2021, NVIDIA CORPORATION. +# Copyright (c) 2021-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
From 372b7b97440bea791b7b0e97b528c9919b28a893 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 14 Jan 2022 02:02:36 +0000 Subject: [PATCH 24/38] Use gpuci_conda_retry to remove metapackages --- ci/gpu/build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/gpu/build.sh b/ci/gpu/build.sh index 665df8858d..98fc750253 100755 --- a/ci/gpu/build.sh +++ b/ci/gpu/build.sh @@ -64,7 +64,7 @@ gpuci_mamba_retry install -c conda-forge -c rapidsai -c rapidsai-nightly -c nvid "rapids-doc-env=${MINOR_VERSION}.*" \ "shap>=0.37,<=0.39" -gpuci_mamba_retry remove -c conda-forge -c rapidsai -c rapidsai-nightly -c nvidia \ +gpuci_conda_retry remove -c conda-forge -c rapidsai -c rapidsai-nightly -c nvidia \ --force rapids-build-env rapids-notebook-env gpuci_mamba_retry install -c conda-forge -c rapidsai -c rapidsai-nightly -c nvidia \ treelite=2.2.0 From 0a4cc3e7369d3544d1408ecb19d416404912b18c Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Sat, 15 Jan 2022 18:49:26 +0000 Subject: [PATCH 25/38] Use Treelite 2.2.1 --- ci/gpu/build.sh | 2 +- conda/recipes/cuml/meta.yaml | 4 ++-- conda/recipes/libcuml/meta.yaml | 4 ++-- cpp/cmake/thirdparty/get_treelite.cmake | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/ci/gpu/build.sh b/ci/gpu/build.sh index 98fc750253..69d5022629 100755 --- a/ci/gpu/build.sh +++ b/ci/gpu/build.sh @@ -67,7 +67,7 @@ gpuci_mamba_retry install -c conda-forge -c rapidsai -c rapidsai-nightly -c nvid gpuci_conda_retry remove -c conda-forge -c rapidsai -c rapidsai-nightly -c nvidia \ --force rapids-build-env rapids-notebook-env gpuci_mamba_retry install -c conda-forge -c rapidsai -c rapidsai-nightly -c nvidia \ - treelite=2.2.0 + treelite=2.2.1 # https://docs.rapids.ai/maintainers/depmgmt/ # gpuci_mamba_retry remove --force rapids-build-env rapids-notebook-env diff --git a/conda/recipes/cuml/meta.yaml b/conda/recipes/cuml/meta.yaml index 16fa766258..e0bc355919 100644 --- a/conda/recipes/cuml/meta.yaml +++ 
b/conda/recipes/cuml/meta.yaml @@ -30,7 +30,7 @@ requirements: - setuptools - cython>=0.29,<0.30 - cmake>=3.20.1 - - treelite=2.2.0 + - treelite=2.2.1 - cudf {{ minor_version }} - libcuml={{ version }} - libcumlprims {{ minor_version }} @@ -44,7 +44,7 @@ requirements: - libcuml={{ version }} - libcumlprims {{ minor_version }} - cupy>=7.8.0,<11.0.0a0 - - treelite=2.2.0 + - treelite=2.2.1 - nccl>=2.9.9 - ucx-py {{ ucx_py_version }} - ucx-proc=*=gpu diff --git a/conda/recipes/libcuml/meta.yaml b/conda/recipes/libcuml/meta.yaml index 05877c35e9..9f395b637b 100644 --- a/conda/recipes/libcuml/meta.yaml +++ b/conda/recipes/libcuml/meta.yaml @@ -43,7 +43,7 @@ requirements: - ucx-proc=*=gpu - libcumlprims {{ minor_version }} - lapack - - treelite=2.2.0 + - treelite=2.2.1 - faiss-proc=*=cuda - gtest=1.10.0 - gmock @@ -55,7 +55,7 @@ requirements: - ucx-py {{ ucx_py_version }} - ucx-proc=*=gpu - {{ pin_compatible('cudatoolkit', max_pin='x', min_pin='x') }} - - treelite=2.2.0 + - treelite=2.2.1 - faiss-proc=*=cuda - libfaiss 1.7.0 *_cuda - libcusolver>=11.2.1 diff --git a/cpp/cmake/thirdparty/get_treelite.cmake b/cpp/cmake/thirdparty/get_treelite.cmake index bcbf65aa6b..4faf14869d 100644 --- a/cpp/cmake/thirdparty/get_treelite.cmake +++ b/cpp/cmake/thirdparty/get_treelite.cmake @@ -54,5 +54,5 @@ function(find_and_configure_treelite) endfunction() -find_and_configure_treelite(VERSION 2.2.0 - PINNED_TAG 2a62c6f569a6b12520cbfc35b4a727bb6bb671b1) +find_and_configure_treelite(VERSION 2.2.1 + PINNED_TAG 1f9c5054ad7433fa88623fccf0ae46a6ba6a27c6) From 9045c5fc92dba4aa6467d3992d38b915677b0d18 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 19 Jan 2022 04:00:29 +0000 Subject: [PATCH 26/38] Implement Span and BitField --- cpp/src/explainer/tree_shap.cu | 95 ++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/cpp/src/explainer/tree_shap.cu b/cpp/src/explainer/tree_shap.cu index 8ebe74efbb..963713c2ab 100644 --- a/cpp/src/explainer/tree_shap.cu +++ 
b/cpp/src/explainer/tree_shap.cu @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -34,6 +35,100 @@ namespace tl = treelite; * for internal use by GPUTreeSHAP. */ namespace { +// A poor man's Span class. +// TODO(hcho3): Remove this class once RAFT implements a span abstraction. +template +class Span { + private: + T* ptr_{nullptr}; + std::size_t size_{0}; + + public: + Span() = default; + __host__ __device__ Span(T* ptr, std::size_t size) : ptr_(ptr), size_(size) {} + __host__ explicit Span(std::vector& vec) : ptr_(vec.data()), size_(vec.size()) {} + __host__ explicit Span(thrust::device_vector& vec) + : ptr_(thrust::raw_pointer_cast(vec.data())), size_(vec.size()) + { + } + __host__ __device__ Span(const Span& other) : ptr_(other.ptr_), size_(other.size_) {} + __host__ __device__ Span(Span&& other) : ptr_(other.ptr_), size_(other.size_) + { + other.ptr_ = nullptr; + other.size_ = 0; + } + __host__ __device__ ~Span() {} + __host__ __device__ Span& operator=(const Span& other) + { + ptr_ = other.ptr_; + size_ = other.size_; + return *this; + } + __host__ __device__ Span& operator=(Span&& other) + { + ptr_ = other.ptr_; + size_ = other.size_; + other.ptr_ = nullptr; + other.size_ = 0; + return *this; + } + __host__ __device__ std::size_t Size() const { return size_; } + __host__ __device__ T* Data() const { return ptr_; } + __host__ __device__ T& operator[](std::size_t offset) const { return *(ptr_ + offset); } + __host__ __device__ Span Subspan(std::size_t offset, std::size_t count) + { + return Span{ptr_ + offset, count}; + } +}; + +// A poor man's bit field, to be used to account for categorical splits in SHAP computation +// Inspired by xgboost::BitFieldContainer +template +class BitField { + private: + static std::size_t constexpr kValueSize = sizeof(T) * 8; + static std::size_t constexpr kOne = 1; // force correct data type + + Span bits_; + + public: + BitField() = default; + __host__ __device__ explicit BitField(Span bits) 
: bits_(bits) {} + __host__ __device__ BitField(const BitField& other) : bits_(other.bits_) {} + BitField& operator=(const BitField& other) = default; + BitField& operator=(BitField&& other) = default; + __host__ __device__ bool Check(std::size_t pos) const + { + T bitmask = kOne << (pos % kValueSize); + return static_cast(bits_[pos / kValueSize] & bitmask); + } + __host__ __device__ void Set(std::size_t pos) + { + T bitmask = kOne << (pos % kValueSize); + bits_[pos / kValueSize] |= bitmask; + } + __host__ __device__ void Intersect(const BitField other) + { + if (bits_.Data() == other.bits_.Data()) { return; } + std::size_t size = min(bits_.Size(), other.bits_.Size()); + for (std::size_t i = 0; i < size; ++i) { + bits_[i] &= other.bits_[i]; + } + if (bits_.Size() > size) { + for (std::size_t i = size; i < bits_.Size(); ++i) { + bits_[i] = 0; + } + } + } + __host__ __device__ std::size_t Size() const { return kValueSize * bits_.Size(); } + __host__ static std::size_t ComputeStorageSize(std::size_t n_cat) + { + return n_cat / kValueSize + (n_cat % kValueSize != 0); + } + + static_assert(!std::is_signed::value, "Must use unsiged type as underlying storage."); +}; + template struct SplitCondition { SplitCondition() = default; From 9528e8856f2c5f23a1cba54cb686998ba3600e9a Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 19 Jan 2022 06:42:56 +0000 Subject: [PATCH 27/38] Make the tree walk logic generic --- cpp/src/explainer/tree_shap.cu | 222 ++++++++++++++++++--------------- 1 file changed, 118 insertions(+), 104 deletions(-) diff --git a/cpp/src/explainer/tree_shap.cu b/cpp/src/explainer/tree_shap.cu index 963713c2ab..cad2a0c0bf 100644 --- a/cpp/src/explainer/tree_shap.cu +++ b/cpp/src/explainer/tree_shap.cu @@ -175,6 +175,70 @@ struct SplitCondition { "ThresholdType must be a float or double"); }; +template +struct PathSegmentExtractor { + using PathElementT = gpu_treeshap::PathElement>; + std::vector& path_segments; + std::size_t& path_idx; + + static constexpr 
ThresholdType inf{std::numeric_limits::infinity()}; + + PathSegmentExtractor(std::vector& path_segments, std::size_t& path_idx) + : path_segments(path_segments), path_idx(path_idx) + { + } + + void node_handler(const tl::Tree& tree, + int child_idx, + int parent_idx, + int group_id, + float v) + { + constexpr auto inf = std::numeric_limits::infinity(); + double zero_fraction = 1.0; + bool has_count_info = false; + if (tree.HasSumHess(parent_idx) && tree.HasSumHess(child_idx)) { + zero_fraction = static_cast(tree.SumHess(child_idx) / tree.SumHess(parent_idx)); + has_count_info = true; + } + if (!has_count_info && tree.HasDataCount(parent_idx) && tree.HasDataCount(child_idx)) { + zero_fraction = static_cast(tree.DataCount(child_idx)) / tree.DataCount(parent_idx); + has_count_info = true; + } + if (!has_count_info) { RAFT_FAIL("Tree model doesn't have data count information"); } + // Encode the range of feature values that flow down this path + bool is_left_path = tree.LeftChild(parent_idx) == child_idx; + if (tree.SplitType(parent_idx) == tl::SplitFeatureType::kCategorical) { + RAFT_FAIL( + "Only trees with numerical splits are supported. " + "Trees with categorical splits are not supported yet."); + } + ThresholdType lower_bound = is_left_path ? -inf : tree.Threshold(parent_idx); + ThresholdType upper_bound = is_left_path ? 
tree.Threshold(parent_idx) : inf; + auto comparison_op = tree.ComparisonOp(parent_idx); + path_segments.push_back(gpu_treeshap::PathElement>{ + path_idx, + tree.SplitIndex(parent_idx), + group_id, + SplitCondition{lower_bound, upper_bound, comparison_op}, + zero_fraction, + v}); + } + + void root_handler(const tl::Tree& tree, + int child_idx, + int group_id, + float v) + { + // Root node has feature -1 + auto comparison_op = tree.ComparisonOp(child_idx); + path_segments.push_back(gpu_treeshap::PathElement>{ + path_idx, -1, group_id, SplitCondition{-inf, inf, comparison_op}, 1.0, v}); + } + + void new_path_handler() { ++path_idx; } +}; + template class TreePathInfoImpl : public ML::Explainer::TreePathInfo { public: @@ -266,76 +330,33 @@ void gpu_treeshap_impl(const TreePathInfoImpl* path_info, namespace ML { namespace Explainer { - -// Traverse a path from the root node to a leaf node and return the list of the path segments -// Note: the path segments will have missing values in path_idx, group_id and v (leaf value). -// The callser is responsible for filling in these fields. -template -std::vector>> traverse_towards_leaf_node( - const tl::Tree& tree, - int leaf_node_id, - const std::vector& parent_id) +// Traverse a path from the root node to a leaf node and call the handler functions for each node. +// The fields group_id and v (leaf value) will be passed to the handler. 
+template +void traverse_towards_leaf_node(const tl::Tree& tree, + int leaf_node_id, + int group_id, + float v, + const std::vector& parent_id, + PathHandler& path_handler) { - std::vector>> path_segments; - int child_idx = leaf_node_id; - int parent_idx = parent_id[child_idx]; - constexpr auto inf = std::numeric_limits::infinity(); - tl::Operator comparison_op = tl::Operator::kNone; + int child_idx = leaf_node_id; + int parent_idx = parent_id[child_idx]; while (parent_idx != -1) { - double zero_fraction = 1.0; - bool has_count_info = false; - if (tree.HasSumHess(parent_idx) && tree.HasSumHess(child_idx)) { - zero_fraction = static_cast(tree.SumHess(child_idx) / tree.SumHess(parent_idx)); - has_count_info = true; - } - if (!has_count_info && tree.HasDataCount(parent_idx) && tree.HasDataCount(child_idx)) { - zero_fraction = static_cast(tree.DataCount(child_idx)) / tree.DataCount(parent_idx); - has_count_info = true; - } - if (!has_count_info) { RAFT_FAIL("Tree model doesn't have data count information"); } - // Encode the range of feature values that flow down this path - bool is_left_path = tree.LeftChild(parent_idx) == child_idx; - if (tree.SplitType(parent_idx) == tl::SplitFeatureType::kCategorical) { - RAFT_FAIL( - "Only trees with numerical splits are supported. " - "Trees with categorical splits are not supported yet."); - } - ThresholdType lower_bound = is_left_path ? -inf : tree.Threshold(parent_idx); - ThresholdType upper_bound = is_left_path ? 
tree.Threshold(parent_idx) : inf; - comparison_op = tree.ComparisonOp(parent_idx); - path_segments.push_back(gpu_treeshap::PathElement>{ - ~std::size_t(0), - tree.SplitIndex(parent_idx), - -1, - SplitCondition{lower_bound, upper_bound, comparison_op}, - zero_fraction, - std::numeric_limits::quiet_NaN()}); + path_handler.node_handler(tree, child_idx, parent_idx, group_id, v); child_idx = parent_idx; parent_idx = parent_id[child_idx]; } - // Root node has feature -1 - comparison_op = tree.ComparisonOp(child_idx); - // Build temporary path segments with unknown path_idx, group_id and leaf value - path_segments.push_back(gpu_treeshap::PathElement>{ - ~std::size_t(0), - -1, - -1, - SplitCondition{-inf, inf, comparison_op}, - 1.0, - std::numeric_limits::quiet_NaN()}); - return path_segments; + path_handler.root_handler(tree, child_idx, group_id, v); } -// Extract the path segments from a single tree. Each path segment will have path_idx field, which -// uniquely identifies the path to which the segment belongs. The path_idx_offset parameter sets -// the path_idx field of the first path segment. -template -std::vector>> -extract_path_segments_from_tree(const std::vector>& tree_list, - std::size_t tree_idx, - bool use_vector_leaf, - int num_groups, - std::size_t path_idx_offset) +// Visit every path segments in a single tree and call handler functions for each segment. +template +void visit_path_segments_in_tree(const std::vector>& tree_list, + std::size_t tree_idx, + bool use_vector_leaf, + int num_groups, + PathHandler& path_handler) { if (num_groups < 1) { RAFT_FAIL("num_groups must be at least 1"); } @@ -350,27 +371,10 @@ extract_path_segments_from_tree(const std::vector>> path_segments; - for (int nid = 0; nid < tree.num_nodes; nid++) { if (tree.IsLeaf(nid)) { // For each leaf node... 
// Extract path segments by traversing the path from the leaf node to the root node - auto path_to_leaf = traverse_towards_leaf_node(tree, nid, parent_id); - // If use_vector_leaf=True: - // * Duplicate the path segments N times, where N = num_groups - // * Insert the duplicated path segments into path_segments - // If use_vector_leaf=False: - // * Insert the path segments into path_segments - auto path_insertor = [&path_to_leaf, &path_segments]( - auto leaf_value, auto path_idx, int group_id) { - for (auto& e : path_to_leaf) { - e.path_idx = path_idx; - e.v = static_cast(leaf_value); - e.group = group_id; - } - path_segments.insert(path_segments.end(), path_to_leaf.cbegin(), path_to_leaf.cend()); - }; + // If use_vector_leaf=True, repeat the path segments N times, where N = num_groups if (use_vector_leaf) { auto leaf_vector = tree.LeafVector(nid); if (leaf_vector.size() != static_cast(num_groups)) { @@ -379,18 +383,40 @@ extract_path_segments_from_tree(const std::vector(leaf_vector.size())); } for (int group_id = 0; group_id < num_groups; ++group_id) { - path_insertor(leaf_vector[group_id], path_idx, group_id); - path_idx++; + traverse_towards_leaf_node( + tree, nid, group_id, leaf_vector[group_id], parent_id, path_handler); + path_handler.new_path_handler(); } } else { - auto leaf_value = tree.LeafValue(nid); int group_id = static_cast(tree_idx) % num_groups; - path_insertor(leaf_value, path_idx, group_id); - path_idx++; + auto leaf_value = tree.LeafValue(nid); + traverse_towards_leaf_node(tree, nid, group_id, leaf_value, parent_id, path_handler); + path_handler.new_path_handler(); } } } - return path_segments; +} + +// Visit every path segments in the whole tree ensemble model +template +void visit_path_segments_in_model(const tl::ModelImpl& model, + PathHandler& path_handler) +{ + int num_groups = 1; + bool use_vector_leaf; + if (model.task_param.num_class > 1) { num_groups = model.task_param.num_class; } + if (model.task_type == tl::TaskType::kBinaryClfRegr || 
+ model.task_type == tl::TaskType::kMultiClfGrovePerClass) { + use_vector_leaf = false; + } else if (model.task_type == tl::TaskType::kMultiClfProbDistLeaf) { + use_vector_leaf = true; + } else { + RAFT_FAIL("Unsupported task_type: %d", static_cast(model.task_type)); + } + + for (std::size_t tree_idx = 0; tree_idx < model.trees.size(); ++tree_idx) { + visit_path_segments_in_tree(model.trees, tree_idx, use_vector_leaf, num_groups, path_handler); + } } template @@ -407,24 +433,12 @@ std::unique_ptr extract_path_info_impl( std::unique_ptr path_info_ptr = std::make_unique>(); auto* path_info = dynamic_cast*>(path_info_ptr.get()); - int num_groups = 1; - bool use_vector_leaf; - if (model.task_param.num_class > 1) { num_groups = model.task_param.num_class; } - if (model.task_type == tl::TaskType::kBinaryClfRegr || - model.task_type == tl::TaskType::kMultiClfGrovePerClass) { - use_vector_leaf = false; - } else if (model.task_type == tl::TaskType::kMultiClfProbDistLeaf) { - use_vector_leaf = true; - } else { - RAFT_FAIL("Unsupported task_type: %d", static_cast(model.task_type)); - } + // Each path segment will have path_idx field, which uniquely identifies the path to which the + // segment belongs. 
std::size_t path_idx = 0; - for (std::size_t tree_idx = 0; tree_idx < model.trees.size(); ++tree_idx) { - auto path_segments = - extract_path_segments_from_tree(model.trees, tree_idx, use_vector_leaf, num_groups, path_idx); - path_info->paths.insert(path_info->paths.end(), path_segments.cbegin(), path_segments.cend()); - if (!path_segments.empty()) { path_idx = path_segments.back().path_idx + 1; } - } + PathSegmentExtractor path_extractor{path_info->paths, path_idx}; + visit_path_segments_in_model(model, path_extractor); + path_info->global_bias = model.param.global_bias; path_info->task_type = model.task_type; path_info->task_param = model.task_param; From c1061ad8cf2b9f365dcdc0d2e2eda793baec7d84 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 19 Jan 2022 20:25:15 +0000 Subject: [PATCH 28/38] First working prototype --- cpp/include/cuml/explainer/tree_shap.hpp | 2 +- cpp/src/explainer/tree_shap.cu | 202 +++++++++++++++--- .../cuml/experimental/explainer/tree_shap.pyx | 2 +- .../cuml/test/explainer/test_gpu_treeshap.py | 66 ++++++ 4 files changed, 245 insertions(+), 27 deletions(-) diff --git a/cpp/include/cuml/explainer/tree_shap.hpp b/cpp/include/cuml/explainer/tree_shap.hpp index 9ce8123073..e2d3355baa 100644 --- a/cpp/include/cuml/explainer/tree_shap.hpp +++ b/cpp/include/cuml/explainer/tree_shap.hpp @@ -34,7 +34,7 @@ class TreePathInfo { }; std::unique_ptr extract_path_info(ModelHandle model); -void gpu_treeshap(const TreePathInfo* path_info, +void gpu_treeshap(TreePathInfo* path_info, const float* data, std::size_t n_rows, std::size_t n_cols, diff --git a/cpp/src/explainer/tree_shap.cu b/cpp/src/explainer/tree_shap.cu index cad2a0c0bf..890efb5fcd 100644 --- a/cpp/src/explainer/tree_shap.cu +++ b/cpp/src/explainer/tree_shap.cu @@ -16,12 +16,14 @@ #include #include +#include #include #include #include #include #include #include +#include #include #include #include @@ -125,19 +127,39 @@ class BitField { { return n_cat / kValueSize + (n_cat % kValueSize != 
0); } + __host__ std::string ToString(bool reverse = false) const + { + std::ostringstream oss; + oss << "Bits storage size: " << bits_.Size() << ", elements: "; + for (auto i = 0; i < bits_.Size(); ++i) { + std::bitset bset(bits_[i]); + std::string s = bset.to_string(); + if (reverse) { std::reverse(s.begin(), s.end()); } + oss << s << ", "; + } + return oss.str(); + } static_assert(!std::is_signed::value, "Must use unsiged type as underlying storage."); }; +using CatBitFieldStorageT = std::uint32_t; +using CatBitField = BitField; +using CatT = std::uint32_t; + template struct SplitCondition { SplitCondition() = default; SplitCondition(ThresholdType feature_lower_bound, ThresholdType feature_upper_bound, - tl::Operator comparison_op) + bool is_missing_branch, + tl::Operator comparison_op, + CatBitField categories) : feature_lower_bound(feature_lower_bound), feature_upper_bound(feature_upper_bound), - comparison_op(comparison_op) + is_missing_branch(is_missing_branch), + comparison_op(comparison_op), + categories(categories) { if (feature_lower_bound > feature_upper_bound) { RAFT_FAIL("Lower bound cannot exceed upper bound"); @@ -151,13 +173,23 @@ struct SplitCondition { // Lower and upper bounds on feature values flowing down this path ThresholdType feature_lower_bound; ThresholdType feature_upper_bound; + bool is_missing_branch; // Comparison operator used in the test. For now only < (kLT) and <= (kLE) // are supported. tl::Operator comparison_op; + CatBitField categories; // Does this instance flow down this path? 
__host__ __device__ bool EvaluateSplit(ThresholdType x) const { + static_assert(std::is_floating_point::value, "x must be a floating point type"); + auto max_representable_int = + static_cast(uint64_t(1) << std::numeric_limits::digits); + if (isnan(x)) { return is_missing_branch; } + if (categories.Size() != 0) { + if (x < 0 || std::fabs(x) > max_representable_int) { return false; } + return categories.Check(static_cast(x)); + } if (comparison_op == tl::Operator::kLE) { return x > feature_lower_bound && x <= feature_upper_bound; } @@ -167,24 +199,74 @@ struct SplitCondition { // Combine two split conditions on the same feature __host__ __device__ void Merge(const SplitCondition& other) { // Combine duplicate features - feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound); - feature_upper_bound = min(feature_upper_bound, other.feature_upper_bound); + if (categories.Size() != 0 || other.categories.Size() != 0) { + categories.Intersect(other.categories); + } else { + feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound); + feature_upper_bound = min(feature_upper_bound, other.feature_upper_bound); + } + is_missing_branch = is_missing_branch && other.is_missing_branch; } static_assert(std::is_same::value || std::is_same::value, "ThresholdType must be a float or double"); }; +template +struct CategoricalSplitCounter { + int n_features; + std::vector n_categories; + // n_categories[K] = number of category values for feature K + // Set to 0 for numerical features + std::vector feature_id; + // feature_id[I] = feature ID associated with the I-th path segment + + CategoricalSplitCounter(int n_features) + : n_features(n_features), n_categories(n_features, 0), feature_id() + { + } + + void node_handler(const tl::Tree& tree, int, int parent_idx, int, float) + { + const auto split_index = tree.SplitIndex(parent_idx); + if (tree.SplitType(parent_idx) == tl::SplitFeatureType::kCategorical) { + CatT max_cat = 0; + for (CatT cat : 
tree.MatchingCategories(parent_idx)) { + if (cat > max_cat) { max_cat = cat; } + } + n_categories[split_index] = std::max(n_categories[split_index], max_cat + 1); + } + feature_id.push_back(split_index); + } + + void root_handler(const tl::Tree&, int, int, float) + { + feature_id.push_back(-1); + } + + void new_path_handler() {} +}; + template struct PathSegmentExtractor { using PathElementT = gpu_treeshap::PathElement>; std::vector& path_segments; std::size_t& path_idx; + std::vector& cat_bitfields; + const std::vector& bitfield_segments; + std::size_t path_segment_idx; static constexpr ThresholdType inf{std::numeric_limits::infinity()}; - PathSegmentExtractor(std::vector& path_segments, std::size_t& path_idx) - : path_segments(path_segments), path_idx(path_idx) + PathSegmentExtractor(std::vector& path_segments, + std::size_t& path_idx, + std::vector& cat_bitfields, + const std::vector& bitfield_segments) + : path_segments(path_segments), + path_idx(path_idx), + cat_bitfields(cat_bitfields), + bitfield_segments(bitfield_segments), + path_segment_idx(0) { } @@ -194,7 +276,6 @@ struct PathSegmentExtractor { int group_id, float v) { - constexpr auto inf = std::numeric_limits::infinity(); double zero_fraction = 1.0; bool has_count_info = false; if (tree.HasSumHess(parent_idx) && tree.HasSumHess(child_idx)) { @@ -207,22 +288,48 @@ struct PathSegmentExtractor { } if (!has_count_info) { RAFT_FAIL("Tree model doesn't have data count information"); } // Encode the range of feature values that flow down this path - bool is_left_path = tree.LeftChild(parent_idx) == child_idx; - if (tree.SplitType(parent_idx) == tl::SplitFeatureType::kCategorical) { - RAFT_FAIL( - "Only trees with numerical splits are supported. 
" - "Trees with categorical splits are not supported yet."); + bool is_left_path = tree.LeftChild(parent_idx) == child_idx; + bool is_missing_branch = tree.DefaultChild(parent_idx) == child_idx; + auto split_type = tree.SplitType(parent_idx); + ThresholdType lower_bound, upper_bound; + tl::Operator comparison_op; + CatBitField categories; + if (split_type == tl::SplitFeatureType::kCategorical) { + auto n_bitfields = + bitfield_segments[path_segment_idx + 1] - bitfield_segments[path_segment_idx]; + categories = CatBitField(Span(cat_bitfields) + .Subspan(bitfield_segments[path_segment_idx], n_bitfields)); + for (CatT cat : tree.MatchingCategories(parent_idx)) { + categories.Set(static_cast(cat)); + } + if (is_left_path) { + for (std::size_t i = bitfield_segments[path_segment_idx]; + i < bitfield_segments[path_segment_idx + 1]; + ++i) { + cat_bitfields[i] = ~cat_bitfields[i]; + } + } + lower_bound = -inf; + upper_bound = inf; + comparison_op = tl::Operator::kNone; + } else { + if (split_type != tl::SplitFeatureType::kNumerical) { + // Assume: split is either numerical or categorical + RAFT_FAIL("Unexpected split type: %d", static_cast(split_type)); + } + categories = CatBitField{}; + lower_bound = is_left_path ? -inf : tree.Threshold(parent_idx); + upper_bound = is_left_path ? tree.Threshold(parent_idx) : inf; + comparison_op = tree.ComparisonOp(parent_idx); } - ThresholdType lower_bound = is_left_path ? -inf : tree.Threshold(parent_idx); - ThresholdType upper_bound = is_left_path ? 
tree.Threshold(parent_idx) : inf; - auto comparison_op = tree.ComparisonOp(parent_idx); path_segments.push_back(gpu_treeshap::PathElement>{ path_idx, tree.SplitIndex(parent_idx), group_id, - SplitCondition{lower_bound, upper_bound, comparison_op}, + SplitCondition{lower_bound, upper_bound, is_missing_branch, comparison_op, categories}, zero_fraction, v}); + ++path_segment_idx; } void root_handler(const tl::Tree& tree, @@ -233,7 +340,8 @@ struct PathSegmentExtractor { // Root node has feature -1 auto comparison_op = tree.ComparisonOp(child_idx); path_segments.push_back(gpu_treeshap::PathElement>{ - path_idx, -1, group_id, SplitCondition{-inf, inf, comparison_op}, 1.0, v}); + path_idx, -1, group_id, SplitCondition{-inf, inf, false, comparison_op, {}}, 1.0, v}); + ++path_segment_idx; } void new_path_handler() { ++path_idx; } @@ -248,7 +356,11 @@ class TreePathInfoImpl : public ML::Explainer::TreePathInfo { tl::TaskType task_type; tl::TaskParam task_param; bool average_tree_output; - std::vector>> paths; + std::vector>> path_segments; + std::vector categorical_bitfields; + std::vector bitfield_segments; + // bitfield_segments[I]: cumulative total count of all bit fields for path segments + // 0, 1, ..., I-1 static_assert(std::is_same::value || std::is_same::value, @@ -287,12 +399,25 @@ class DenseDatasetWrapper { }; template -void gpu_treeshap_impl(const TreePathInfoImpl* path_info, +void gpu_treeshap_impl(TreePathInfoImpl* path_info, const float* data, std::size_t n_rows, std::size_t n_cols, float* out_preds) { + // Marshall bit fields to GPU memory + auto& categorical_bitfields = path_info->categorical_bitfields; + auto& path_segments = path_info->path_segments; + auto& bitfield_segments = path_info->bitfield_segments; + thrust::device_vector d_cat_bitfields(categorical_bitfields.cbegin(), + categorical_bitfields.cend()); + for (std::size_t path_seg_idx = 0; path_seg_idx < path_segments.size(); ++path_seg_idx) { + auto n_bitfields = bitfield_segments[path_seg_idx + 
1] - bitfield_segments[path_seg_idx]; + path_segments[path_seg_idx].split_condition.categories = + CatBitField(Span(d_cat_bitfields) + .Subspan(bitfield_segments[path_seg_idx], n_bitfields)); + } + DenseDatasetWrapper X(data, n_rows, n_cols); std::size_t num_groups = 1; @@ -303,8 +428,8 @@ void gpu_treeshap_impl(const TreePathInfoImpl* path_info, thrust::device_ptr out_preds_ptr = thrust::device_pointer_cast(out_preds); gpu_treeshap::GPUTreeShap(X, - path_info->paths.begin(), - path_info->paths.end(), + path_segments.begin(), + path_segments.end(), num_groups, out_preds_ptr, out_preds_ptr + pred_size); @@ -433,10 +558,37 @@ std::unique_ptr extract_path_info_impl( std::unique_ptr path_info_ptr = std::make_unique>(); auto* path_info = dynamic_cast*>(path_info_ptr.get()); + /* 1. Scan the model for categorical splits and pre-allocate bit fields. */ + CategoricalSplitCounter cat_counter{model.num_feature}; + visit_path_segments_in_model(model, cat_counter); + + std::size_t n_path_segments = cat_counter.feature_id.size(); + std::vector n_bitfields(n_path_segments, 0); + // n_bitfields[I] : number of bit fields for path segment I + + std::transform(cat_counter.feature_id.cbegin(), + cat_counter.feature_id.cend(), + n_bitfields.begin(), + [&](std::int64_t fid) -> std::size_t { + if (fid == -1) { return 0; } + return CatBitField::ComputeStorageSize(cat_counter.n_categories[fid]); + }); + + path_info->bitfield_segments = std::vector(n_path_segments + 1, 0); + std::inclusive_scan( + n_bitfields.cbegin(), n_bitfields.cend(), path_info->bitfield_segments.begin() + 1); + + path_info->categorical_bitfields = + std::vector(path_info->bitfield_segments.back(), 0); + + /* 2. Scan the model again, to extract path segments. */ // Each path segment will have path_idx field, which uniquely identifies the path to which the // segment belongs. 
std::size_t path_idx = 0; - PathSegmentExtractor path_extractor{path_info->paths, path_idx}; + PathSegmentExtractor path_extractor{path_info->path_segments, + path_idx, + path_info->categorical_bitfields, + path_info->bitfield_segments}; visit_path_segments_in_model(model, path_extractor); path_info->global_bias = model.param.global_bias; @@ -458,7 +610,7 @@ std::unique_ptr extract_path_info(ModelHandle model) }); } -void gpu_treeshap(const TreePathInfo* path_info, +void gpu_treeshap(TreePathInfo* path_info, const float* data, std::size_t n_rows, std::size_t n_cols, @@ -466,12 +618,12 @@ void gpu_treeshap(const TreePathInfo* path_info, { switch (path_info->GetThresholdType()) { case TreePathInfo::ThresholdTypeEnum::kDouble: { - const auto* path_info_casted = dynamic_cast*>(path_info); + auto* path_info_casted = dynamic_cast*>(path_info); gpu_treeshap_impl(path_info_casted, data, n_rows, n_cols, out_preds); } break; case TreePathInfo::ThresholdTypeEnum::kFloat: default: { - const auto* path_info_casted = dynamic_cast*>(path_info); + auto* path_info_casted = dynamic_cast*>(path_info); gpu_treeshap_impl(path_info_casted, data, n_rows, n_cols, out_preds); } break; } diff --git a/python/cuml/experimental/explainer/tree_shap.pyx b/python/cuml/experimental/explainer/tree_shap.pyx index 8e0d8b29dd..f62835f299 100644 --- a/python/cuml/experimental/explainer/tree_shap.pyx +++ b/python/cuml/experimental/explainer/tree_shap.pyx @@ -50,7 +50,7 @@ cdef extern from "cuml/explainer/tree_shap.hpp" namespace "ML::Explainer": pass cdef unique_ptr[TreePathInfo] extract_path_info(ModelHandle model) except + - cdef void gpu_treeshap(const TreePathInfo* path_info, + cdef void gpu_treeshap(TreePathInfo* path_info, const float* data, size_t n_rows, size_t n_cols, diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py b/python/cuml/test/explainer/test_gpu_treeshap.py index 5d790185f1..a0e940e547 100644 --- a/python/cuml/test/explainer/test_gpu_treeshap.py +++ 
b/python/cuml/test/explainer/test_gpu_treeshap.py @@ -17,6 +17,7 @@ import pytest import treelite import numpy as np +import pandas as pd import cupy as cp import cudf from cuml.experimental.explainer.tree_shap import TreeExplainer @@ -258,3 +259,68 @@ def test_sklearn_rf_classifier(n_classes): np.testing.assert_almost_equal(out, correct_out, decimal=5) np.testing.assert_almost_equal(explainer.expected_value, expected_value, decimal=5) + + +@pytest.mark.skipif(not has_xgboost(), reason="need to install xgboost") +def test_xgb_toy_categorical(): + X = pd.DataFrame({'dummy': np.zeros(5, dtype=np.float32), + 'x': np.array([0, 1, 2, 3, 4], dtype=np.int32)}) + y = np.array([0, 0, 1, 1, 1], dtype=np.float32) + X['x'] = X['x'].astype("category") + dtrain = xgb.DMatrix(X, y, enable_categorical=True) + params = {"tree_method": "gpu_hist", "eval_metric": "error", + "objective": "binary:logistic", "max_depth": 2, + "min_child_weight": 0, "lambda": 0} + xgb_model = xgb.train(params, dtrain, num_boost_round=1, + evals=[(dtrain, 'train')]) + explainer = TreeExplainer(model=xgb_model) + out = explainer.shap_values(X) + + ref_out = xgb_model.predict(dtrain, pred_contribs=True) + np.testing.assert_almost_equal(out, ref_out[:, :-1], decimal=5) + np.testing.assert_almost_equal(explainer.expected_value, ref_out[0, -1], + decimal=5) + + +@pytest.mark.parametrize('n_classes', [2, 3]) +@pytest.mark.skipif(not has_xgboost(), reason="need to install xgboost") +@pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn") +def test_xgb_categorical(n_classes): + n_samples = 100 + n_features = 8 + X, y = make_classification(n_samples=n_samples, n_features=n_features, + n_informative=n_features, n_redundant=0, + n_repeated=0, n_classes=n_classes, + random_state=2021) + X, y = X.astype(np.float32), y.astype(np.float32) + + # Turn the first 4 columns into categorical columns + X = pd.DataFrame({f'f{i}': X[:, i] for i in range(n_features)}) + for i in range(4): + column = f'f{i}' 
+ X[column] = pd.qcut(X[column], 4, labels=range(4)) + + dtrain = xgb.DMatrix(X, y, enable_categorical=True) + params = {"tree_method": "gpu_hist", "max_depth": 6, + "base_score": 0.5, "seed": 0, "predictor": "gpu_predictor"} + if n_classes == 2: + params["objective"] = "binary:logistic" + params["eval_metric"] = "logloss" + else: + params["objective"] = "multi:softprob" + params["eval_metric"] = "mlogloss" + params["num_class"] = n_classes + xgb_model = xgb.train(params, dtrain, num_boost_round=10, + evals=[(dtrain, 'train')]) + explainer = TreeExplainer(model=xgb_model) + out = explainer.shap_values(X) + + ref_out = xgb_model.predict(dtrain, pred_contribs=True) + if n_classes == 2: + ref_out, ref_expected_value = ref_out[:, :-1], ref_out[0, -1] + else: + ref_out = ref_out.transpose((1, 0, 2)) + ref_out, ref_expected_value = ref_out[:, :, :-1], ref_out[:, 0, -1] + np.testing.assert_almost_equal(out, ref_out, decimal=5) + np.testing.assert_almost_equal(explainer.expected_value, + ref_expected_value, decimal=5) From b7c2e112656f2af391d06d948e221dc48bf2dc51 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 19 Jan 2022 20:51:26 +0000 Subject: [PATCH 29/38] Add more XGBoost tests --- .../cuml/test/explainer/test_gpu_treeshap.py | 89 ++++++++++++++++--- 1 file changed, 77 insertions(+), 12 deletions(-) diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py b/python/cuml/test/explainer/test_gpu_treeshap.py index a0e940e547..ac45fa258d 100644 --- a/python/cuml/test/explainer/test_gpu_treeshap.py +++ b/python/cuml/test/explainer/test_gpu_treeshap.py @@ -14,6 +14,7 @@ # limitations under the License. 
# +import json import pytest import treelite import numpy as np @@ -36,6 +37,48 @@ from sklearn.ensemble import RandomForestClassifier as sklrfc +def make_classification_with_categorical(*, n_samples, n_features, + n_categorical, n_informative, n_redundant, n_repeated, n_classes, + random_state): + X, y = make_classification(n_samples=n_samples, n_features=n_features, + n_informative=n_informative, + n_redundant=n_redundant, n_repeated=n_repeated, + n_classes=n_classes, random_state=random_state) + X, y = X.astype(np.float32), y.astype(np.float32) + + # Turn some columns into categorical, by taking quartiles + X = pd.DataFrame({f'f{i}': X[:, i] for i in range(n_features)}) + for i in range(n_categorical): + column = f'f{i}' + X[column] = pd.qcut(X[column], 4, labels=range(4)) + return X, y + + +def make_regression_with_categorical(*, n_samples, n_features, n_categorical, + n_informative, random_state): + X, y = make_regression(n_samples=n_samples, n_features=n_features, + n_informative=n_informative, n_targets=1, + random_state=random_state) + X, y = X.astype(np.float32), y.astype(np.float32) + + # Turn some columns into categorical, by taking quartiles + X = pd.DataFrame({f'f{i}': X[:, i] for i in range(n_features)}) + for i in range(n_categorical): + column = f'f{i}' + X[column] = pd.qcut(X[column], 4, labels=range(4)) + return X, y + + +def count_categorical_splits(tl_model): + model_dump = json.loads(tl_model.dump_as_json(pretty_print=False)) + count = 0 + for tree in model_dump["trees"]: + for node in tree["nodes"]: + if "split_type" in node and node["split_type"] == "categorical": + count += 1 + return count + + @pytest.mark.parametrize('objective', ['reg:linear', 'reg:squarederror', 'reg:squaredlogerror', 'reg:pseudohubererror']) @@ -285,20 +328,13 @@ def test_xgb_toy_categorical(): @pytest.mark.parametrize('n_classes', [2, 3]) @pytest.mark.skipif(not has_xgboost(), reason="need to install xgboost") @pytest.mark.skipif(not has_sklearn(), reason="need to 
install scikit-learn") -def test_xgb_categorical(n_classes): +def test_xgb_classifier_with_categorical(n_classes): n_samples = 100 n_features = 8 - X, y = make_classification(n_samples=n_samples, n_features=n_features, - n_informative=n_features, n_redundant=0, - n_repeated=0, n_classes=n_classes, - random_state=2021) - X, y = X.astype(np.float32), y.astype(np.float32) - - # Turn the first 4 columns into categorical columns - X = pd.DataFrame({f'f{i}': X[:, i] for i in range(n_features)}) - for i in range(4): - column = f'f{i}' - X[column] = pd.qcut(X[column], 4, labels=range(4)) + X, y = make_classification_with_categorical(n_samples=n_samples, + n_features=n_features, n_categorical=4, n_informative=n_features, + n_redundant=0, n_repeated=0, n_classes=n_classes, + random_state=2022) dtrain = xgb.DMatrix(X, y, enable_categorical=True) params = {"tree_method": "gpu_hist", "max_depth": 6, @@ -312,6 +348,8 @@ def test_xgb_categorical(n_classes): params["num_class"] = n_classes xgb_model = xgb.train(params, dtrain, num_boost_round=10, evals=[(dtrain, 'train')]) + assert count_categorical_splits(treelite.Model.from_xgboost(xgb_model)) > 0 + explainer = TreeExplainer(model=xgb_model) out = explainer.shap_values(X) @@ -324,3 +362,30 @@ def test_xgb_categorical(n_classes): np.testing.assert_almost_equal(out, ref_out, decimal=5) np.testing.assert_almost_equal(explainer.expected_value, ref_expected_value, decimal=5) + + +@pytest.mark.skipif(not has_xgboost(), reason="need to install xgboost") +@pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn") +def test_xgb_regressor_with_categorical(): + n_samples = 100 + n_features = 8 + X, y = make_regression_with_categorical(n_samples=n_samples, + n_features=n_features, n_categorical=4, n_informative=n_features, + random_state=2022) + + dtrain = xgb.DMatrix(X, y, enable_categorical=True) + params = {"tree_method": "gpu_hist", "max_depth": 6, + "base_score": 0.5, "seed": 0, "predictor": "gpu_predictor", + 
"objective": "reg:squarederror", "eval_metric": "rmse"} + xgb_model = xgb.train(params, dtrain, num_boost_round=10, + evals=[(dtrain, 'train')]) + assert count_categorical_splits(treelite.Model.from_xgboost(xgb_model)) > 0 + + explainer = TreeExplainer(model=xgb_model) + out = explainer.shap_values(X) + + ref_out = xgb_model.predict(dtrain, pred_contribs=True) + ref_out, ref_expected_value = ref_out[:, :-1], ref_out[0, -1] + np.testing.assert_almost_equal(out, ref_out, decimal=5) + np.testing.assert_almost_equal(explainer.expected_value, + ref_expected_value, decimal=5) From 46f2cf840a85317e653c54715ab7c0e2ab55e21d Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 19 Jan 2022 21:03:17 +0000 Subject: [PATCH 30/38] Update copyright year --- cpp/include/cuml/explainer/tree_shap.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/cuml/explainer/tree_shap.hpp b/cpp/include/cuml/explainer/tree_shap.hpp index e2d3355baa..fb928b5387 100644 --- a/cpp/include/cuml/explainer/tree_shap.hpp +++ b/cpp/include/cuml/explainer/tree_shap.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
From 22caa8c9db2b7557b96ea0ecd71a3a026931f95a Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 19 Jan 2022 21:47:09 +0000 Subject: [PATCH 31/38] Add support for LightGBM --- cpp/src/explainer/tree_shap.cu | 3 +- .../cuml/experimental/explainer/tree_shap.pyx | 7 +- .../cuml/test/explainer/test_gpu_treeshap.py | 81 +++++++++++++++++-- 3 files changed, 83 insertions(+), 8 deletions(-) diff --git a/cpp/src/explainer/tree_shap.cu b/cpp/src/explainer/tree_shap.cu index 890efb5fcd..1045499e2d 100644 --- a/cpp/src/explainer/tree_shap.cu +++ b/cpp/src/explainer/tree_shap.cu @@ -302,7 +302,8 @@ struct PathSegmentExtractor { for (CatT cat : tree.MatchingCategories(parent_idx)) { categories.Set(static_cast(cat)); } - if (is_left_path) { + bool use_right = tree.CategoriesListRightChild(parent_idx); + if ((use_right && is_left_path) || (!use_right && !is_left_path)) { for (std::size_t i = bitfield_segments[path_segment_idx]; i < bitfield_segments[path_segment_idx + 1]; ++i) { diff --git a/python/cuml/experimental/explainer/tree_shap.pyx b/python/cuml/experimental/explainer/tree_shap.pyx index f62835f299..550623e203 100644 --- a/python/cuml/experimental/explainer/tree_shap.pyx +++ b/python/cuml/experimental/explainer/tree_shap.pyx @@ -26,6 +26,7 @@ from cuml.ensemble import RandomForestClassifier as curfc from libcpp.memory cimport unique_ptr from libc.stdint cimport uintptr_t from libcpp.utility cimport move +import re import numpy as np import treelite @@ -111,9 +112,13 @@ class TreeExplainer: # Handle various kinds of tree model objects cls = model.__class__ # XGBoost model object - if cls.__module__ == 'xgboost.core' and cls.__name__ == 'Booster': + if re.match(r'xgboost.*$', cls.__module__) and cls.__name__ == 'Booster': model = treelite.Model.from_xgboost(model) handle = model.handle.value + # LightGBM model object + if re.match(r'lightgbm.*$', cls.__module__) and cls.__name__ == 'Booster': + model = treelite.Model.from_lightgbm(model) + handle = model.handle.value # 
cuML RF model object elif isinstance(model, (curfr, curfc)): try: diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py b/python/cuml/test/explainer/test_gpu_treeshap.py index ac45fa258d..1182009def 100644 --- a/python/cuml/test/explainer/test_gpu_treeshap.py +++ b/python/cuml/test/explainer/test_gpu_treeshap.py @@ -22,13 +22,15 @@ import cupy as cp import cudf from cuml.experimental.explainer.tree_shap import TreeExplainer -from cuml.common.import_utils import has_xgboost, has_shap, has_sklearn +from cuml.common.import_utils import has_xgboost, has_lightgbm, has_shap, has_sklearn from cuml.common.exceptions import NotFittedError from cuml.ensemble import RandomForestRegressor as curfr from cuml.ensemble import RandomForestClassifier as curfc if has_xgboost(): import xgboost as xgb +if has_lightgbm(): + import lightgbm as lgb if has_shap(): import shap if has_sklearn(): @@ -69,7 +71,7 @@ def make_regression_with_categorical(*, n_samples, n_features, n_categorical, return X, y -def count_categorical_splits(tl_model): +def count_categorical_split(tl_model): model_dump = json.loads(tl_model.dump_as_json(pretty_print=False)) count = 0 for tree in model_dump["trees"]: @@ -109,7 +111,7 @@ def test_xgb_regressor(objective): correct_out = ref_explainer.shap_values(X) np.testing.assert_almost_equal(out, correct_out, decimal=5) np.testing.assert_almost_equal(explainer.expected_value, - ref_explainer.expected_value) + ref_explainer.expected_value, decimal=5) @pytest.mark.parametrize('objective,n_classes', @@ -153,7 +155,7 @@ def test_xgb_classifier(objective, n_classes): correct_out = ref_explainer.shap_values(X) np.testing.assert_almost_equal(out, correct_out, decimal=5) np.testing.assert_almost_equal(explainer.expected_value, - ref_explainer.expected_value) + ref_explainer.expected_value, decimal=5) def test_degenerate_cases(): @@ -348,7 +350,7 @@ def test_xgb_classifier_with_categorical(n_classes): params["num_class"] = n_classes xgb_model = xgb.train(params, 
dtrain, num_boost_round=10, evals=[(dtrain, 'train')]) - assert count_categorical_splits(treelite.Model.from_xgboost(xgb_model)) > 0 + assert count_categorical_split(treelite.Model.from_xgboost(xgb_model)) > 0 explainer = TreeExplainer(model=xgb_model) out = explainer.shap_values(X) @@ -379,7 +381,7 @@ def test_xgb_regressor_with_categorical(): "objective": "reg:squarederror", "eval_metric": "rmse"} xgb_model = xgb.train(params, dtrain, num_boost_round=10, evals=[(dtrain, 'train')]) - assert count_categorical_splits(treelite.Model.from_xgboost(xgb_model)) > 0 + assert count_categorical_split(treelite.Model.from_xgboost(xgb_model)) > 0 explainer = TreeExplainer(model=xgb_model) out = explainer.shap_values(X) @@ -389,3 +391,70 @@ def test_xgb_regressor_with_categorical(): np.testing.assert_almost_equal(out, ref_out, decimal=5) np.testing.assert_almost_equal(explainer.expected_value, ref_expected_value, decimal=5) + + +@pytest.mark.skipif(not has_lightgbm(), reason="need to install lightgbm") +@pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn") +def test_lightgbm_regressor_with_categorical(): + n_samples = 100 + n_features = 8 + n_categorical = 8 + X, y = make_regression_with_categorical(n_samples=n_samples, + n_features=n_features, n_categorical=n_categorical, + n_informative=n_features, random_state=2022) + + dtrain = lgb.Dataset(X, label=y, categorical_feature=range(n_categorical)) + params = {"num_leaves": 64, "seed": 0, "objective": "regression", + "metric": "rmse", "min_data_per_group": 1} + lgb_model = lgb.train(params, dtrain, num_boost_round=10, + valid_sets=[dtrain], valid_names=['train']) + assert count_categorical_split(treelite.Model.from_lightgbm(lgb_model)) > 0 + + explainer = TreeExplainer(model=lgb_model) + out = explainer.shap_values(X) + + ref_explainer = shap.explainers.Tree(model=lgb_model) + ref_out = ref_explainer.shap_values(X) + np.testing.assert_almost_equal(out, ref_out, decimal=5) + 
np.testing.assert_almost_equal(explainer.expected_value, + ref_explainer.expected_value, decimal=5) + + +@pytest.mark.parametrize('n_classes', [2, 3]) +@pytest.mark.skipif(not has_lightgbm(), reason="need to install lightgbm") +@pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn") +def test_lightgbm_classifier_with_categorical(n_classes): + n_samples = 100 + n_features = 8 + n_categorical = 8 + X, y = make_classification_with_categorical(n_samples=n_samples, + n_features=n_features, n_categorical=n_categorical, + n_informative=n_features, n_redundant=0, n_repeated=0, + n_classes=n_classes, random_state=2022) + + dtrain = lgb.Dataset(X, label=y, categorical_feature=range(n_categorical)) + params = {"num_leaves": 64, "seed": 0, "min_data_per_group": 1} + if n_classes == 2: + params["objective"] = "binary" + params["metric"] = "binary_logloss" + else: + params["objective"] = "multiclass" + params["metric"] = "multi_logloss" + params["num_class"] = n_classes + lgb_model = lgb.train(params, dtrain, num_boost_round=10, + valid_sets=[dtrain], valid_names=['train']) + assert count_categorical_split(treelite.Model.from_lightgbm(lgb_model)) > 0 + + explainer = TreeExplainer(model=lgb_model) + out = explainer.shap_values(X) + + ref_explainer = shap.explainers.Tree(model=lgb_model) + ref_out = np.array(ref_explainer.shap_values(X)) + if n_classes == 2: + ref_out = ref_out[1, :, :] + ref_expected_value = ref_explainer.expected_value[1] + else: + ref_expected_value = ref_explainer.expected_value + np.testing.assert_almost_equal(out, ref_out, decimal=5) + np.testing.assert_almost_equal(explainer.expected_value, + ref_expected_value, decimal=5) From fd23c6f08b18eaaec46e6c5b1cb8c0fdca21fe32 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 19 Jan 2022 21:53:57 +0000 Subject: [PATCH 32/38] Fix style --- .../cuml/experimental/explainer/tree_shap.pyx | 5 ++- .../cuml/test/explainer/test_gpu_treeshap.py | 43 ++++++++++--------- 2 files changed, 26 
insertions(+), 22 deletions(-) diff --git a/python/cuml/experimental/explainer/tree_shap.pyx b/python/cuml/experimental/explainer/tree_shap.pyx index 550623e203..e5df6de1c9 100644 --- a/python/cuml/experimental/explainer/tree_shap.pyx +++ b/python/cuml/experimental/explainer/tree_shap.pyx @@ -111,12 +111,13 @@ class TreeExplainer: def __init__(self, *, model): # Handle various kinds of tree model objects cls = model.__class__ + cls_module, cls_name = cls.__module__, cls.__name__ # XGBoost model object - if re.match(r'xgboost.*$', cls.__module__) and cls.__name__ == 'Booster': + if re.match(r'xgboost.*$', cls_module) and cls_name == 'Booster': model = treelite.Model.from_xgboost(model) handle = model.handle.value # LightGBM model object - if re.match(r'lightgbm.*$', cls.__module__) and cls.__name__ == 'Booster': + if re.match(r'lightgbm.*$', cls_module) and cls_name == 'Booster': model = treelite.Model.from_lightgbm(model) handle = model.handle.value # cuML RF model object diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py b/python/cuml/test/explainer/test_gpu_treeshap.py index 1182009def..639275060d 100644 --- a/python/cuml/test/explainer/test_gpu_treeshap.py +++ b/python/cuml/test/explainer/test_gpu_treeshap.py @@ -22,7 +22,8 @@ import cupy as cp import cudf from cuml.experimental.explainer.tree_shap import TreeExplainer -from cuml.common.import_utils import has_xgboost, has_lightgbm, has_shap, has_sklearn +from cuml.common.import_utils import has_xgboost, has_lightgbm, has_shap +from cuml.common.import_utils import has_sklearn from cuml.common.exceptions import NotFittedError from cuml.ensemble import RandomForestRegressor as curfr from cuml.ensemble import RandomForestClassifier as curfc @@ -39,9 +40,9 @@ from sklearn.ensemble import RandomForestClassifier as sklrfc -def make_classification_with_categorical(*, n_samples, n_features, - n_categorical, n_informative, n_redundant, n_repeated, n_classes, - random_state): +def 
make_classification_with_categorical( + *, n_samples, n_features, n_categorical, n_informative, n_redundant, + n_repeated, n_classes, random_state): X, y = make_classification(n_samples=n_samples, n_features=n_features, n_informative=n_informative, n_redundant=n_redundant, n_repeated=n_repeated, @@ -56,8 +57,8 @@ def make_classification_with_categorical(*, n_samples, n_features, return X, y -def make_regression_with_categorical(*, n_samples, n_features, n_categorical, - n_informative, random_state): +def make_regression_with_categorical( + *, n_samples, n_features, n_categorical, n_informative, random_state): X, y = make_regression(n_samples=n_samples, n_features=n_features, n_informative=n_informative, n_targets=1, random_state=random_state) @@ -333,10 +334,10 @@ def test_xgb_toy_categorical(): def test_xgb_classifier_with_categorical(n_classes): n_samples = 100 n_features = 8 - X, y = make_classification_with_categorical(n_samples=n_samples, - n_features=n_features, n_categorical=4, n_informative=n_features, - n_redundant=0, n_repeated=0, n_classes=n_classes, - random_state=2022) + X, y = make_classification_with_categorical( + n_samples=n_samples, n_features=n_features, n_categorical=4, + n_informative=n_features, n_redundant=0, n_repeated=0, + n_classes=n_classes, random_state=2022) dtrain = xgb.DMatrix(X, y, enable_categorical=True) params = {"tree_method": "gpu_hist", "max_depth": 6, @@ -371,9 +372,9 @@ def test_xgb_classifier_with_categorical(n_classes): def test_xgb_regressor_with_categorical(): n_samples = 100 n_features = 8 - X, y = make_regression_with_categorical(n_samples=n_samples, - n_features=n_features, n_categorical=4, n_informative=n_features, - random_state=2022) + X, y = make_regression_with_categorical( + n_samples=n_samples, n_features=n_features, n_categorical=4, + n_informative=n_features, random_state=2022) dtrain = xgb.DMatrix(X, y, enable_categorical=True) params = {"tree_method": "gpu_hist", "max_depth": 6, @@ -399,9 +400,10 @@ def 
test_lightgbm_regressor_with_categorical(): n_samples = 100 n_features = 8 n_categorical = 8 - X, y = make_regression_with_categorical(n_samples=n_samples, - n_features=n_features, n_categorical=n_categorical, - n_informative=n_features, random_state=2022) + X, y = make_regression_with_categorical( + n_samples=n_samples, n_features=n_features, + n_categorical=n_categorical, n_informative=n_features, + random_state=2022) dtrain = lgb.Dataset(X, label=y, categorical_feature=range(n_categorical)) params = {"num_leaves": 64, "seed": 0, "objective": "regression", @@ -427,10 +429,11 @@ def test_lightgbm_classifier_with_categorical(n_classes): n_samples = 100 n_features = 8 n_categorical = 8 - X, y = make_classification_with_categorical(n_samples=n_samples, - n_features=n_features, n_categorical=n_categorical, - n_informative=n_features, n_redundant=0, n_repeated=0, - n_classes=n_classes, random_state=2022) + X, y = make_classification_with_categorical( + n_samples=n_samples, n_features=n_features, + n_categorical=n_categorical, n_informative=n_features, + n_redundant=0, n_repeated=0, n_classes=n_classes, + random_state=2022) dtrain = lgb.Dataset(X, label=y, categorical_feature=range(n_categorical)) params = {"num_leaves": 64, "seed": 0, "min_data_per_group": 1} From cdb46d3654025341d677e877904c88e31cc91349 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 25 Jan 2022 20:39:19 +0000 Subject: [PATCH 33/38] Remove temporary install step in build.sh --- ci/gpu/build.sh | 5 ----- 1 file changed, 5 deletions(-) diff --git a/ci/gpu/build.sh b/ci/gpu/build.sh index 7b4e516c5d..a4634c2f37 100755 --- a/ci/gpu/build.sh +++ b/ci/gpu/build.sh @@ -64,11 +64,6 @@ gpuci_mamba_retry install -c conda-forge -c rapidsai -c rapidsai-nightly -c nvid "rapids-doc-env=${MINOR_VERSION}.*" \ "shap>=0.37,<=0.39" -gpuci_conda_retry remove -c conda-forge -c rapidsai -c rapidsai-nightly -c nvidia \ - --force rapids-build-env rapids-notebook-env -gpuci_mamba_retry install -c conda-forge -c rapidsai 
-c rapidsai-nightly -c nvidia \ - treelite=2.2.1 - # https://docs.rapids.ai/maintainers/depmgmt/ # gpuci_conda_retry remove --force rapids-build-env rapids-notebook-env # gpuci_mamba_retry install -y "your-pkg=1.0.0" From e2345459f55403e786838ab067d9b514a9fc83da Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 26 Jan 2022 00:29:12 +0000 Subject: [PATCH 34/38] Respond to reviewer's comment --- cpp/src/explainer/tree_shap.cu | 7 +++++++ python/cuml/experimental/explainer/tree_shap.pyx | 6 ++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/cpp/src/explainer/tree_shap.cu b/cpp/src/explainer/tree_shap.cu index 1045499e2d..2fe25ed3b2 100644 --- a/cpp/src/explainer/tree_shap.cu +++ b/cpp/src/explainer/tree_shap.cu @@ -295,6 +295,10 @@ struct PathSegmentExtractor { tl::Operator comparison_op; CatBitField categories; if (split_type == tl::SplitFeatureType::kCategorical) { + /* Create bit fields to store the list of categories associated with this path. + The bit fields will be used to quickly decide whether a feature value should + flow down this path or not. + The test in the test node is of form: x \in { list of category values } */ auto n_bitfields = bitfield_segments[path_segment_idx + 1] - bitfield_segments[path_segment_idx]; categories = CatBitField(Span(cat_bitfields) @@ -302,6 +306,9 @@ struct PathSegmentExtractor { for (CatT cat : tree.MatchingCategories(parent_idx)) { categories.Set(static_cast(cat)); } + // If this path is not the path that's taken the categorical test evaluates to be true, + // then flip all the bits in the bit fields. This step is needed because we first built + // the bit fields according to the list given in the categorical test.
bool use_right = tree.CategoriesListRightChild(parent_idx); if ((use_right && is_left_path) || (!use_right && !is_left_path)) { for (std::size_t i = bitfield_segments[path_segment_idx]; diff --git a/python/cuml/experimental/explainer/tree_shap.pyx b/python/cuml/experimental/explainer/tree_shap.pyx index e5df6de1c9..f260d63f02 100644 --- a/python/cuml/experimental/explainer/tree_shap.pyx +++ b/python/cuml/experimental/explainer/tree_shap.pyx @@ -34,10 +34,8 @@ if has_sklearn(): from sklearn.ensemble import RandomForestRegressor as sklrfr from sklearn.ensemble import RandomForestClassifier as sklrfc else: - class PlaceHolder: - pass - sklrfr = PlaceHolder - sklrfc = PlaceHolder + sklrfr = object + sklrfc = object cdef extern from "treelite/c_api.h": ctypedef void* ModelHandle From 63416411dc6b7a1d0ac86ff8af8179417ce05842 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 26 Jan 2022 00:29:29 +0000 Subject: [PATCH 35/38] Make shap optional in tests --- python/cuml/test/explainer/test_gpu_treeshap.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py b/python/cuml/test/explainer/test_gpu_treeshap.py index 639275060d..8606deb344 100644 --- a/python/cuml/test/explainer/test_gpu_treeshap.py +++ b/python/cuml/test/explainer/test_gpu_treeshap.py @@ -396,6 +396,7 @@ def test_xgb_regressor_with_categorical(): @pytest.mark.skipif(not has_lightgbm(), reason="need to install lightgbm") @pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn") +@pytest.mark.skipif(not has_shap(), reason="need to install shap") def test_lightgbm_regressor_with_categorical(): n_samples = 100 n_features = 8 @@ -425,6 +426,7 @@ def test_lightgbm_regressor_with_categorical(): @pytest.mark.parametrize('n_classes', [2, 3]) @pytest.mark.skipif(not has_lightgbm(), reason="need to install lightgbm") @pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn") +@pytest.mark.skipif(not has_shap(), reason="need to install 
shap") def test_lightgbm_classifier_with_categorical(n_classes): n_samples = 100 n_features = 8 From 9f6f0c9b55c68916ea0e8713fa5bdc09be088950 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 26 Jan 2022 08:07:25 +0000 Subject: [PATCH 36/38] Add coverage for missing values --- .../cuml/test/explainer/test_gpu_treeshap.py | 45 +++++++++++++++---- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py b/python/cuml/test/explainer/test_gpu_treeshap.py index 8606deb344..267f14cf3b 100644 --- a/python/cuml/test/explainer/test_gpu_treeshap.py +++ b/python/cuml/test/explainer/test_gpu_treeshap.py @@ -105,11 +105,18 @@ def test_xgb_regressor(objective): evals=[(dtrain, 'train')]) tl_model = treelite.Model.from_xgboost(xgb_model) + # Insert NaN randomly into X + X_test = X.copy() + n_nan = int(np.floor(X.size * 0.1)) + rng = np.random.default_rng(seed=0) + index_nan = rng.choice(X.size, size=n_nan, replace=False) + X_test.ravel()[index_nan] = np.nan + explainer = TreeExplainer(model=tl_model) - out = explainer.shap_values(X) + out = explainer.shap_values(X_test) ref_explainer = shap.explainers.Tree(model=xgb_model) - correct_out = ref_explainer.shap_values(X) + correct_out = ref_explainer.shap_values(X_test) np.testing.assert_almost_equal(out, correct_out, decimal=5) np.testing.assert_almost_equal(explainer.expected_value, ref_explainer.expected_value, decimal=5) @@ -149,11 +156,18 @@ def test_xgb_classifier(objective, n_classes): params['num_class'] = n_classes xgb_model = xgb.train(params, dtrain=dtrain, num_boost_round=num_round) + # Insert NaN randomly into X + X_test = X.copy() + n_nan = int(np.floor(X.size * 0.1)) + rng = np.random.default_rng(seed=0) + index_nan = rng.choice(X.size, size=n_nan, replace=False) + X_test.ravel()[index_nan] = np.nan + explainer = TreeExplainer(model=xgb_model) - out = explainer.shap_values(X) + out = explainer.shap_values(X_test) ref_explainer = 
shap.explainers.Tree(model=xgb_model) - correct_out = ref_explainer.shap_values(X) + correct_out = ref_explainer.shap_values(X_test) np.testing.assert_almost_equal(out, correct_out, decimal=5) np.testing.assert_almost_equal(explainer.expected_value, ref_explainer.expected_value, decimal=5) @@ -353,10 +367,18 @@ def test_xgb_classifier_with_categorical(n_classes): evals=[(dtrain, 'train')]) assert count_categorical_split(treelite.Model.from_xgboost(xgb_model)) > 0 + # Insert NaN randomly into X + X_test = X.values.copy() + n_nan = int(np.floor(X.size * 0.1)) + rng = np.random.default_rng(seed=0) + index_nan = rng.choice(X.size, size=n_nan, replace=False) + X_test.ravel()[index_nan] = np.nan + explainer = TreeExplainer(model=xgb_model) - out = explainer.shap_values(X) + out = explainer.shap_values(X_test) - ref_out = xgb_model.predict(dtrain, pred_contribs=True) + dtest = xgb.DMatrix(X_test) + ref_out = xgb_model.predict(dtest, pred_contribs=True, validate_features=False) if n_classes == 2: ref_out, ref_expected_value = ref_out[:, :-1], ref_out[0, -1] else: @@ -450,11 +472,18 @@ def test_lightgbm_classifier_with_categorical(n_classes): valid_sets=[dtrain], valid_names=['train']) assert count_categorical_split(treelite.Model.from_lightgbm(lgb_model)) > 0 + # Insert NaN randomly into X + X_test = X.values.copy() + n_nan = int(np.floor(X.size * 0.1)) + rng = np.random.default_rng(seed=0) + index_nan = rng.choice(X.size, size=n_nan, replace=False) + X_test.ravel()[index_nan] = np.nan + explainer = TreeExplainer(model=lgb_model) - out = explainer.shap_values(X) + out = explainer.shap_values(X_test) ref_explainer = shap.explainers.Tree(model=lgb_model) - ref_out = np.array(ref_explainer.shap_values(X)) + ref_out = np.array(ref_explainer.shap_values(X_test)) if n_classes == 2: ref_out = ref_out[1, :, :] ref_expected_value = ref_explainer.expected_value[1] From 48a6a836b2d3dab583b1b41f9c35064d54bff804 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 26 Jan 2022 11:27:53 
-0800 Subject: [PATCH 37/38] Fix typo in comment --- cpp/src/explainer/tree_shap.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/explainer/tree_shap.cu b/cpp/src/explainer/tree_shap.cu index 9785e60034..d597cec98c 100644 --- a/cpp/src/explainer/tree_shap.cu +++ b/cpp/src/explainer/tree_shap.cu @@ -306,7 +306,7 @@ struct PathSegmentExtractor { for (CatT cat : tree.MatchingCategories(parent_idx)) { categories.Set(static_cast(cat)); } - // If this path is not the path that's taken the categorical test evaluates to be true, + // If this path is not the path that's taken when the categorical test evaluates to be true, // then flip all the bits in the bit fields. This step is needed because we first built // the bit fields according to the list given in the categorical test. bool use_right = tree.CategoriesListRightChild(parent_idx); From f97106a5cf70e8cddd570d560ea8d99a03dd444b Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 26 Jan 2022 15:36:29 -0800 Subject: [PATCH 38/38] Fix style --- python/cuml/test/explainer/test_gpu_treeshap.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py b/python/cuml/test/explainer/test_gpu_treeshap.py index 267f14cf3b..b601170a77 100644 --- a/python/cuml/test/explainer/test_gpu_treeshap.py +++ b/python/cuml/test/explainer/test_gpu_treeshap.py @@ -378,7 +378,8 @@ def test_xgb_classifier_with_categorical(n_classes): out = explainer.shap_values(X_test) dtest = xgb.DMatrix(X_test) - ref_out = xgb_model.predict(dtest, pred_contribs=True, validate_features=False) + ref_out = xgb_model.predict(dtest, pred_contribs=True, + validate_features=False) if n_classes == 2: ref_out, ref_expected_value = ref_out[:, :-1], ref_out[0, -1] else: