diff --git a/cpp/src/explainer/tree_shap.cu b/cpp/src/explainer/tree_shap.cu
index dc243e364e..8ebe74efbb 100644
--- a/cpp/src/explainer/tree_shap.cu
+++ b/cpp/src/explainer/tree_shap.cu
@@ -15,10 +15,12 @@
  */
 
 #include
+#include
 #include
 #include
 #include
 #include
+#include
 #include
 #include
 #include
 
@@ -135,8 +137,7 @@ void gpu_treeshap_impl(const TreePathInfoImpl<ThresholdType>* path_info,
   DenseDatasetWrapper X(data, n_rows, n_cols);
 
   std::size_t num_groups = 1;
-  if (path_info->task_type == tl::TaskType::kMultiClfGrovePerClass &&
-      path_info->task_param.num_class > 1) {
+  if (path_info->task_param.num_class > 1) {
     num_groups = static_cast<std::size_t>(path_info->task_param.num_class);
   }
   std::size_t pred_size = n_rows * num_groups * (n_cols + 1);
@@ -171,6 +172,132 @@ void gpu_treeshap_impl(const TreePathInfoImpl<ThresholdType>* path_info,
 
 namespace ML {
 namespace Explainer {
+// Traverse a path from the root node to a leaf node and return the list of the path segments
+// Note: the path segments will have missing values in path_idx, group_id and v (leaf value).
+// The caller is responsible for filling in these fields.
+template <typename ThresholdType, typename LeafType>
+std::vector<gpu_treeshap::PathElement<SplitCondition<ThresholdType>>> traverse_towards_leaf_node(
+  const tl::Tree<ThresholdType, LeafType>& tree,
+  int leaf_node_id,
+  const std::vector<int>& parent_id)
+{
+  std::vector<gpu_treeshap::PathElement<SplitCondition<ThresholdType>>> path_segments;
+  int child_idx              = leaf_node_id;
+  int parent_idx             = parent_id[child_idx];
+  constexpr auto inf         = std::numeric_limits<ThresholdType>::infinity();
+  tl::Operator comparison_op = tl::Operator::kNone;
+  while (parent_idx != -1) {
+    double zero_fraction = 1.0;
+    bool has_count_info  = false;
+    if (tree.HasSumHess(parent_idx) && tree.HasSumHess(child_idx)) {
+      zero_fraction  = static_cast<double>(tree.SumHess(child_idx) / tree.SumHess(parent_idx));
+      has_count_info = true;
+    }
+    if (!has_count_info && tree.HasDataCount(parent_idx) && tree.HasDataCount(child_idx)) {
+      zero_fraction  = static_cast<double>(tree.DataCount(child_idx)) / tree.DataCount(parent_idx);
+      has_count_info = true;
+    }
+    if (!has_count_info) { RAFT_FAIL("Tree model doesn't have data count information"); }
+    // Encode the range of feature values that flow down this path
+    bool is_left_path = tree.LeftChild(parent_idx) == child_idx;
+    if (tree.SplitType(parent_idx) == tl::SplitFeatureType::kCategorical) {
+      RAFT_FAIL(
+        "Only trees with numerical splits are supported. "
+        "Trees with categorical splits are not supported yet.");
+    }
+    ThresholdType lower_bound = is_left_path ? -inf : tree.Threshold(parent_idx);
+    ThresholdType upper_bound = is_left_path ? tree.Threshold(parent_idx) : inf;
+    comparison_op             = tree.ComparisonOp(parent_idx);
+    path_segments.push_back(gpu_treeshap::PathElement<SplitCondition<ThresholdType>>{
+      ~std::size_t(0),
+      tree.SplitIndex(parent_idx),
+      -1,
+      SplitCondition{lower_bound, upper_bound, comparison_op},
+      zero_fraction,
+      std::numeric_limits<float>::quiet_NaN()});
+    child_idx  = parent_idx;
+    parent_idx = parent_id[child_idx];
+  }
+  // Root node has feature -1
+  comparison_op = tree.ComparisonOp(child_idx);
+  // Build temporary path segments with unknown path_idx, group_id and leaf value
+  path_segments.push_back(gpu_treeshap::PathElement<SplitCondition<ThresholdType>>{
+    ~std::size_t(0),
+    -1,
+    -1,
+    SplitCondition{-inf, inf, comparison_op},
+    1.0,
+    std::numeric_limits<float>::quiet_NaN()});
+  return path_segments;
+}
+
+// Extract the path segments from a single tree. Each path segment will have the path_idx field,
+// which uniquely identifies the path to which the segment belongs. The path_idx_offset parameter
+// sets the path_idx field of the first path segment.
+template <typename ThresholdType, typename LeafType>
+std::vector<gpu_treeshap::PathElement<SplitCondition<ThresholdType>>>
+extract_path_segments_from_tree(const std::vector<tl::Tree<ThresholdType, LeafType>>& tree_list,
+                                std::size_t tree_idx,
+                                bool use_vector_leaf,
+                                int num_groups,
+                                std::size_t path_idx_offset)
+{
+  if (num_groups < 1) { RAFT_FAIL("num_groups must be at least 1"); }
+
+  const tl::Tree<ThresholdType, LeafType>& tree = tree_list[tree_idx];
+
+  // Compute parent ID of each node
+  std::vector<int> parent_id(tree.num_nodes, -1);
+  for (int i = 0; i < tree.num_nodes; i++) {
+    if (!tree.IsLeaf(i)) {
+      parent_id[tree.LeftChild(i)]  = i;
+      parent_id[tree.RightChild(i)] = i;
+    }
+  }
+
+  std::size_t path_idx = path_idx_offset;
+  std::vector<gpu_treeshap::PathElement<SplitCondition<ThresholdType>>> path_segments;
+
+  for (int nid = 0; nid < tree.num_nodes; nid++) {
+    if (tree.IsLeaf(nid)) {  // For each leaf node...
+      // Extract path segments by traversing the path from the leaf node to the root node
+      auto path_to_leaf = traverse_towards_leaf_node(tree, nid, parent_id);
+      // If use_vector_leaf=True:
+      // * Duplicate the path segments N times, where N = num_groups
+      // * Insert the duplicated path segments into path_segments
+      // If use_vector_leaf=False:
+      // * Insert the path segments into path_segments
+      auto path_insertor = [&path_to_leaf, &path_segments](
+                             auto leaf_value, auto path_idx, int group_id) {
+        for (auto& e : path_to_leaf) {
+          e.path_idx = path_idx;
+          e.v        = static_cast<float>(leaf_value);
+          e.group    = group_id;
+        }
+        path_segments.insert(path_segments.end(), path_to_leaf.cbegin(), path_to_leaf.cend());
+      };
+      if (use_vector_leaf) {
+        auto leaf_vector = tree.LeafVector(nid);
+        if (leaf_vector.size() != static_cast<std::size_t>(num_groups)) {
+          RAFT_FAIL("Expected leaf vector of length %d but got %d instead",
+                    num_groups,
+                    static_cast<int>(leaf_vector.size()));
+        }
+        for (int group_id = 0; group_id < num_groups; ++group_id) {
+          path_insertor(leaf_vector[group_id], path_idx, group_id);
+          path_idx++;
+        }
+      } else {
+        auto leaf_value = tree.LeafValue(nid);
+        int group_id    = static_cast<int>(tree_idx) % num_groups;
+        path_insertor(leaf_value, path_idx, group_id);
+        path_idx++;
+      }
+    }
+  }
+  return path_segments;
+}
+
 template <typename ThresholdType, typename LeafType>
 std::unique_ptr<TreePathInfo> extract_path_info_impl(
   const tl::ModelImpl<ThresholdType, LeafType>& model)
@@ -178,81 +305,30 @@ std::unique_ptr<TreePathInfo> extract_path_info_impl(
   if (!std::is_same<ThresholdType, LeafType>::value) {
     RAFT_FAIL("ThresholdType and LeafType must be identical");
   }
-  if (model.task_type != tl::TaskType::kBinaryClfRegr &&
-      model.task_type != tl::TaskType::kMultiClfGrovePerClass) {
-    RAFT_FAIL("cuML RF / scikit-learn classifiers are not yet supported");
+  if (!std::is_same<ThresholdType, float>::value && !std::is_same<ThresholdType, double>::value) {
+    RAFT_FAIL("ThresholdType must be either float32 or float64");
   }
+
   std::unique_ptr<TreePathInfo> path_info_ptr = std::make_unique<TreePathInfoImpl<ThresholdType>>();
   auto* path_info = dynamic_cast<TreePathInfoImpl<ThresholdType>*>(path_info_ptr.get());
 
-  std::size_t path_idx = 0;
-  int tree_idx         = 0;
-  int num_groups       = 1;
-  if (model.task_type == tl::TaskType::kMultiClfGrovePerClass && model.task_param.num_class > 1) {
-    num_groups = model.task_param.num_class;
+  int num_groups = 1;
+  bool use_vector_leaf;
+  if (model.task_param.num_class > 1) { num_groups = model.task_param.num_class; }
+  if (model.task_type == tl::TaskType::kBinaryClfRegr ||
+      model.task_type == tl::TaskType::kMultiClfGrovePerClass) {
+    use_vector_leaf = false;
+  } else if (model.task_type == tl::TaskType::kMultiClfProbDistLeaf) {
+    use_vector_leaf = true;
+  } else {
+    RAFT_FAIL("Unsupported task_type: %d", static_cast<int>(model.task_type));
   }
-  for (const tl::Tree<ThresholdType, LeafType>& tree : model.trees) {
-    int group_id = tree_idx % num_groups;
-    std::vector<int> parent_id(tree.num_nodes, -1);
-    // Compute parent ID of each node
-    for (int i = 0; i < tree.num_nodes; i++) {
-      if (!tree.IsLeaf(i)) {
-        parent_id[tree.LeftChild(i)]  = i;
-        parent_id[tree.RightChild(i)] = i;
-      }
-    }
-
-    // Find leaf nodes
-    // Work backwards from leaf to root, order does not matter
-    // It's also possible to work from root to leaf
-    for (int i = 0; i < tree.num_nodes; i++) {
-      if (tree.IsLeaf(i)) {
-        auto v = static_cast<float>(tree.LeafValue(i));
-        int child_idx  = i;
-        int parent_idx = parent_id[child_idx];
-        constexpr auto inf         = std::numeric_limits<ThresholdType>::infinity();
-        tl::Operator comparison_op = tl::Operator::kNone;
-        while (parent_idx != -1) {
-          double zero_fraction = 1.0;
-          bool has_count_info  = false;
-          if (tree.HasSumHess(parent_idx) && tree.HasSumHess(child_idx)) {
-            zero_fraction  = static_cast<double>(tree.SumHess(child_idx) / tree.SumHess(parent_idx));
-            has_count_info = true;
-          }
-          if (tree.HasDataCount(parent_idx) && tree.HasDataCount(child_idx)) {
-            zero_fraction =
-              static_cast<double>(tree.DataCount(child_idx)) / tree.DataCount(parent_idx);
-            has_count_info = true;
-          }
-          if (!has_count_info) { RAFT_FAIL("Tree model doesn't have data count information"); }
-          // Encode the range of feature values that flow down this path
-          bool is_left_path = tree.LeftChild(parent_idx) == child_idx;
-          if (tree.SplitType(parent_idx) == tl::SplitFeatureType::kCategorical) {
-            RAFT_FAIL(
-              "Only trees with numerical splits are supported. "
-              "Trees with categorical splits are not supported yet.");
-          }
-          ThresholdType lower_bound = is_left_path ? -inf : tree.Threshold(parent_idx);
-          ThresholdType upper_bound = is_left_path ? tree.Threshold(parent_idx) : inf;
-          comparison_op             = tree.ComparisonOp(parent_idx);
-          path_info->paths.push_back(gpu_treeshap::PathElement<SplitCondition<ThresholdType>>{
-            path_idx,
-            tree.SplitIndex(parent_idx),
-            group_id,
-            SplitCondition{lower_bound, upper_bound, comparison_op},
-            zero_fraction,
-            v});
-          child_idx  = parent_idx;
-          parent_idx = parent_id[child_idx];
-        }
-        // Root node has feature -1
-        comparison_op = tree.ComparisonOp(child_idx);
-        path_info->paths.push_back(gpu_treeshap::PathElement<SplitCondition<ThresholdType>>{
-          path_idx, -1, group_id, SplitCondition{-inf, inf, comparison_op}, 1.0, v});
-        path_idx++;
-      }
-    }
-    tree_idx++;
+  std::size_t path_idx = 0;
+  for (std::size_t tree_idx = 0; tree_idx < model.trees.size(); ++tree_idx) {
+    auto path_segments =
+      extract_path_segments_from_tree(model.trees, tree_idx, use_vector_leaf, num_groups, path_idx);
+    path_info->paths.insert(path_info->paths.end(), path_segments.cbegin(), path_segments.cend());
+    if (!path_segments.empty()) { path_idx = path_segments.back().path_idx + 1; }
   }
   path_info->global_bias = model.param.global_bias;
   path_info->task_type   = model.task_type;
diff --git a/python/cuml/experimental/explainer/tree_shap.pyx b/python/cuml/experimental/explainer/tree_shap.pyx
index f42d26585a..8e0d8b29dd 100644
--- a/python/cuml/experimental/explainer/tree_shap.pyx
+++ b/python/cuml/experimental/explainer/tree_shap.pyx
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2021, NVIDIA CORPORATION.
+# Copyright (c) 2021-2022, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
 
 from cuml.common import input_to_cuml_array
 from cuml.common.array import CumlArray
+from cuml.common.import_utils import has_sklearn
 from cuml.common.input_utils import determine_array_type
 from cuml.common.exceptions import NotFittedError
 from cuml.fil.fil import TreeliteModel
@@ -28,6 +29,15 @@ from libcpp.utility cimport move
 import numpy as np
 import treelite
 
+if has_sklearn():
+    from sklearn.ensemble import RandomForestRegressor as sklrfr
+    from sklearn.ensemble import RandomForestClassifier as sklrfc
+else:
+    class PlaceHolder:
+        pass
+    sklrfr = PlaceHolder
+    sklrfc = PlaceHolder
+
 cdef extern from "treelite/c_api.h":
     ctypedef void* ModelHandle
     cdef int TreeliteQueryNumClass(ModelHandle handle, size_t* out)
@@ -105,16 +115,17 @@ class TreeExplainer:
             model = treelite.Model.from_xgboost(model)
             handle = model.handle.value
         # cuML RF model object
-        elif isinstance(model, curfr):
+        elif isinstance(model, (curfr, curfc)):
             try:
                 model = model.convert_to_treelite_model()
             except NotFittedError as e:
                 raise NotFittedError(
                     'Cannot compute SHAP for un-fitted model') from e
             handle = model.handle
-        elif isinstance(model, curfc):
-            raise NotImplementedError(
-                'cuML RF classifiers are not supported yet')
+        # scikit-learn RF model object
+        elif isinstance(model, (sklrfr, sklrfc)):
+            model = treelite.sklearn.import_model(model)
+            handle = model.handle.value
         elif isinstance(model, treelite.Model):
             handle = model.handle.value
         elif isinstance(model, TreeliteModel):
diff --git a/python/cuml/test/explainer/test_gpu_treeshap.py b/python/cuml/test/explainer/test_gpu_treeshap.py
index acbd49ed00..e984ba76d5 100644
--- a/python/cuml/test/explainer/test_gpu_treeshap.py
+++ b/python/cuml/test/explainer/test_gpu_treeshap.py
@@ -20,16 +20,19 @@
 
 import cupy as cp
 import cudf
 from cuml.experimental.explainer.tree_shap import TreeExplainer
-from cuml.common.import_utils import has_xgboost, has_shap
+from cuml.common.import_utils import has_xgboost, has_shap, has_sklearn
 from cuml.common.exceptions import NotFittedError
 from cuml.ensemble import RandomForestRegressor as curfr
 from cuml.ensemble import RandomForestClassifier as curfc
-from sklearn.datasets import make_regression, make_classification
 
 if has_xgboost():
     import xgboost as xgb
 if has_shap():
     import shap
+if has_sklearn():
+    from sklearn.datasets import make_regression, make_classification
+    from sklearn.ensemble import RandomForestRegressor as sklrfr
+    from sklearn.ensemble import RandomForestClassifier as sklrfc
 
 
 @pytest.mark.parametrize('objective', ['reg:linear', 'reg:squarederror',
@@ -37,6 +40,7 @@
                                        'reg:pseudohubererror'])
 @pytest.mark.skipif(not has_xgboost(), reason="need to install xgboost")
 @pytest.mark.skipif(not has_shap(), reason="need to install shap")
+@pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn")
 def test_xgb_regressor(objective):
     n_samples = 100
     X, y = make_regression(n_samples=n_samples, n_features=8, n_informative=8,
@@ -57,7 +61,7 @@ def test_xgb_regressor(objective):
     explainer = TreeExplainer(model=tl_model)
     out = explainer.shap_values(X)
 
-    ref_explainer = shap.TreeExplainer(model=xgb_model)
+    ref_explainer = shap.explainers.Tree(model=xgb_model)
     correct_out = ref_explainer.shap_values(X)
     np.testing.assert_almost_equal(out, correct_out, decimal=5)
     np.testing.assert_almost_equal(explainer.expected_value,
@@ -80,6 +84,7 @@
                                        'multi:softmax', 'multi:softprob'])
 @pytest.mark.skipif(not has_xgboost(), reason="need to install xgboost")
 @pytest.mark.skipif(not has_shap(), reason="need to install shap")
+@pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn")
 def test_xgb_classifier(objective, n_classes):
     n_samples = 100
     X, y = make_classification(n_samples=n_samples, n_features=8,
@@ -100,7 +105,7 @@ def test_xgb_classifier(objective, n_classes):
     explainer = TreeExplainer(model=xgb_model)
     out = explainer.shap_values(X)
 
-    ref_explainer = shap.TreeExplainer(model=xgb_model)
+    ref_explainer = shap.explainers.Tree(model=xgb_model)
     correct_out = ref_explainer.shap_values(X)
     np.testing.assert_almost_equal(out, correct_out, decimal=5)
     np.testing.assert_almost_equal(explainer.expected_value,
@@ -132,6 +137,7 @@ def test_degenerate_cases():
 
 
 @pytest.mark.parametrize('input_type', ['numpy', 'cupy', 'cudf'])
+@pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn")
 def test_cuml_rf_regressor(input_type):
     n_samples = 100
     X, y = make_regression(n_samples=n_samples, n_features=8, n_informative=8,
@@ -150,31 +156,105 @@ def test_cuml_rf_regressor(input_type):
     explainer = TreeExplainer(model=cuml_model)
     out = explainer.shap_values(X)
-    # SHAP values should add up to predicted score
-    shap_sum = np.sum(out, axis=1) + explainer.expected_value
     if input_type == 'cupy':
         pred = pred.get()
-        shap_sum = shap_sum.get()
+        out = out.get()
+        expected_value = explainer.expected_value.get()
     elif input_type == 'cudf':
         pred = pred.to_numpy()
-        shap_sum = shap_sum.get()
+        out = out.get()
+        expected_value = explainer.expected_value.get()
+    else:
+        expected_value = explainer.expected_value
+    # SHAP values should add up to predicted score
+    shap_sum = np.sum(out, axis=1) + expected_value
     np.testing.assert_almost_equal(shap_sum, pred, decimal=4)
 
 
+@pytest.mark.parametrize('input_type', ['numpy', 'cupy', 'cudf'])
 @pytest.mark.parametrize('n_classes', [2, 5])
-def test_cuml_rf_classifier(n_classes):
+@pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn")
+def test_cuml_rf_classifier(n_classes, input_type):
     n_samples = 100
     X, y = make_classification(n_samples=n_samples, n_features=8,
                                n_informative=8, n_redundant=0, n_repeated=0,
                                n_classes=n_classes, random_state=2021)
     X, y = X.astype(np.float32), y.astype(np.float32)
+    if input_type == 'cupy':
+        X, y = cp.array(X), cp.array(y)
+    elif input_type == 'cudf':
+        X, y = cudf.DataFrame(X), cudf.Series(y)
     cuml_model = curfc(max_features=1.0, max_samples=0.1, n_bins=128,
                        min_samples_leaf=2, random_state=123,
                        n_streams=1, n_estimators=10, max_leaves=-1,
                        max_depth=16, accuracy_metric="mse")
     cuml_model.fit(X, y)
+    pred = cuml_model.predict_proba(X)
+
+    explainer = TreeExplainer(model=cuml_model)
+    out = explainer.shap_values(X)
+    if input_type == 'cupy':
+        pred = pred.get()
+        out = out.get()
+        expected_value = explainer.expected_value.get()
+    elif input_type == 'cudf':
+        pred = pred.to_numpy()
+        out = out.get()
+        expected_value = explainer.expected_value.get()
+    else:
+        expected_value = explainer.expected_value
+    # SHAP values should add up to predicted score
+    expected_value = expected_value.reshape(-1, 1)
+    shap_sum = np.sum(out, axis=2) + np.tile(expected_value, (1, n_samples))
+    pred = np.transpose(pred, (1, 0))
+    np.testing.assert_almost_equal(shap_sum, pred, decimal=4)
+
+
+@pytest.mark.skipif(not has_shap(), reason="need to install shap")
+@pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn")
+def test_sklearn_rf_regressor():
+    n_samples = 100
+    X, y = make_regression(n_samples=n_samples, n_features=8, n_informative=8,
+                           n_targets=1, random_state=2021)
+    X, y = X.astype(np.float32), y.astype(np.float32)
+    skl_model = sklrfr(max_features=1.0, max_samples=0.1,
+                       min_samples_leaf=2, random_state=123,
+                       n_estimators=10, max_depth=16)
+    skl_model.fit(X, y)
+
+    explainer = TreeExplainer(model=skl_model)
+    out = explainer.shap_values(X)
+
+    ref_explainer = shap.explainers.Tree(model=skl_model)
+    correct_out = ref_explainer.shap_values(X)
+    np.testing.assert_almost_equal(out, correct_out, decimal=5)
+    np.testing.assert_almost_equal(explainer.expected_value,
+                                   ref_explainer.expected_value, decimal=5)
+
+
+@pytest.mark.parametrize('n_classes', [2, 3, 5])
+@pytest.mark.skipif(not has_shap(), reason="need to install shap")
+@pytest.mark.skipif(not has_sklearn(), reason="need to install scikit-learn")
+def test_sklearn_rf_classifier(n_classes):
+    n_samples = 100
+    X, y = make_classification(n_samples=n_samples, n_features=8,
+                               n_informative=8, n_redundant=0, n_repeated=0,
+                               n_classes=n_classes, random_state=2021)
+    X, y = X.astype(np.float32), y.astype(np.float32)
+    skl_model = sklrfc(max_features=1.0, max_samples=0.1,
+                       min_samples_leaf=2, random_state=123,
+                       n_estimators=10, max_depth=16)
+    skl_model.fit(X, y)
+
+    explainer = TreeExplainer(model=skl_model)
+    out = explainer.shap_values(X)
 
-    with pytest.raises(RuntimeError):
-        # cuML RF classifier is not supported yet
-        explainer = TreeExplainer(model=cuml_model)
-        explainer.shap_values(X)
+    ref_explainer = shap.explainers.Tree(model=skl_model)
+    correct_out = np.array(ref_explainer.shap_values(X))
+    expected_value = ref_explainer.expected_value
+    if n_classes == 2:
+        correct_out = correct_out[1, :, :]
+        expected_value = expected_value[1:]
+    np.testing.assert_almost_equal(out, correct_out, decimal=5)
+    np.testing.assert_almost_equal(explainer.expected_value,
+                                   expected_value, decimal=5)
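
Reviewer note (not part of the patch): the heart of the C++ change is traverse_towards_leaf_node(), which walks from each leaf up to the root via the precomputed parent_id array and emits one path segment per edge for GPUTreeShap. A rough Python rendering of that walk follows; the `tree` object and its accessor names (data_count, left_child, threshold, split_index) are hypothetical stand-ins for Treelite's C++ accessors, used only to illustrate the logic.

    import math

    def path_segments_for_leaf(tree, leaf_id, parent_id):
        segments = []
        child, parent = leaf_id, parent_id[leaf_id]
        while parent != -1:
            # Fraction of training rows reaching `parent` that continue on to
            # `child` -- the zero_fraction consumed by GPUTreeShap. The C++
            # code prefers SumHess() and falls back to DataCount().
            zero_fraction = tree.data_count(child) / tree.data_count(parent)
            # Interval of feature values routed from `parent` down to `child`
            is_left = tree.left_child(parent) == child
            threshold = tree.threshold(parent)
            lower = -math.inf if is_left else threshold
            upper = threshold if is_left else math.inf
            # path_idx, group id and leaf value are filled in later by the
            # caller, exactly as in the C++ version
            segments.append((tree.split_index(parent), (lower, upper), zero_fraction))
            child, parent = parent, parent_id[parent]
        # Sentinel segment for the root: feature -1, unbounded interval
        segments.append((-1, (-math.inf, math.inf), 1.0))
        return segments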
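
Reviewer note (not part of the patch): a minimal end-to-end sketch of the newly supported scikit-learn classifier path, assuming a cuML build that contains this change; the API calls mirror the tests above.

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier
    from cuml.experimental.explainer.tree_shap import TreeExplainer

    X, y = make_classification(n_samples=100, n_features=8, n_informative=8,
                               n_redundant=0, n_repeated=0, n_classes=3,
                               random_state=2021)
    X = X.astype(np.float32)
    skl_model = RandomForestClassifier(n_estimators=10, max_depth=16,
                                       random_state=123).fit(X, y)

    # Previously this raised; with this patch the sklearn model is imported
    # via treelite.sklearn.import_model() under the hood.
    explainer = TreeExplainer(model=skl_model)
    out = explainer.shap_values(X)
    # Per the tests above, for a multiclass model the per-class SHAP values
    # plus explainer.expected_value reconstruct the predicted scores.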