Skip to content

Commit

Permalink
TreeExplainer extensions (rapidsai#4697)
Browse files Browse the repository at this point in the history
Stacked on rapidsai#4671.

- Remove extra redundant class in python layer.
- Simplify the interface between C++ and python using variants. 
- Fix rapidsai#4670 by allowing double precision data
- Document TreeExplainer
- Add interventional shap method
- Add shapley interactions and taylor interactions
- Promote from experimental
- Support sklearn estimator types from xgb/lgbm (i.e. no need to convert to booster before using TreeExplainer)

Authors:
  - Rory Mitchell (https://github.com/RAMitchell)

Approvers:
  - Philip Hyunsu Cho (https://github.com/hcho3)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: rapidsai#4697
  • Loading branch information
RAMitchell authored May 11, 2022
1 parent 8ba9147 commit f7ccdca
Show file tree
Hide file tree
Showing 9 changed files with 778 additions and 362 deletions.
55 changes: 40 additions & 15 deletions cpp/include/cuml/explainer/tree_shap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,51 @@
#include <cstdint>
#include <cuml/ensemble/treelite_defs.hpp>
#include <memory>
#include <variant>

namespace ML {
namespace Explainer {

// An abstract class representing an opaque handle to path information
// extracted from a tree model. The implementation in tree_shap.cu will
// define an internal class that inherits from this abtract class.
class TreePathInfo {
public:
enum class ThresholdTypeEnum : std::uint8_t { kFloat, kDouble };
virtual ThresholdTypeEnum GetThresholdType() const = 0;
virtual ~TreePathInfo() {}
};

std::unique_ptr<TreePathInfo> extract_path_info(ModelHandle model);
void gpu_treeshap(TreePathInfo* path_info,
const float* data,
template <typename T>
class TreePathInfo;

using TreePathHandle =
std::variant<std::shared_ptr<TreePathInfo<float>>, std::shared_ptr<TreePathInfo<double>>>;

using FloatPointer = std::variant<float*, double*>;

TreePathHandle extract_path_info(ModelHandle model);

void gpu_treeshap(TreePathHandle path_info,
const FloatPointer data,
std::size_t n_rows,
std::size_t n_cols,
float* out_preds);
FloatPointer out_preds,
std::size_t out_preds_size);

void gpu_treeshap_interventional(TreePathHandle path_info,
const FloatPointer data,
std::size_t n_rows,
std::size_t n_cols,
const FloatPointer background_data,
std::size_t background_n_rows,
std::size_t background_n_cols,
FloatPointer out_preds,
std::size_t out_preds_size);

void gpu_treeshap_interactions(TreePathHandle path_info,
const FloatPointer data,
std::size_t n_rows,
std::size_t n_cols,
FloatPointer out_preds,
std::size_t out_preds_size);

void gpu_treeshap_taylor_interactions(TreePathHandle path_info,
const FloatPointer data,
std::size_t n_rows,
std::size_t n_cols,
FloatPointer out_preds,
std::size_t out_preds_size);

} // namespace Explainer
} // namespace ML
} // namespace ML
Loading

0 comments on commit f7ccdca

Please sign in to comment.