diff --git a/include/LightGBM/boosting.h b/include/LightGBM/boosting.h index 31bb430f0aed..b6286c238006 100644 --- a/include/LightGBM/boosting.h +++ b/include/LightGBM/boosting.h @@ -166,10 +166,11 @@ class LIGHTGBM_EXPORT Boosting { * \brief Feature contributions for the model's prediction of one record * \param feature_values Feature value on this record * \param output Prediction result for this record - * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all models are evaluated. */ - virtual void PredictContrib(const double* features, double* output, - const PredictionEarlyStopInstance* early_stop) const = 0; + virtual void PredictContrib(const double* features, double* output) const = 0; + + virtual void PredictContribByMap(const std::unordered_map& features, + std::vector>* output) const = 0; /*! * \brief Dump model to json format string diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index 9d7c6e61dd27..ac2164a93775 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -33,6 +33,9 @@ typedef void* BoosterHandle; /*!< \brief Handle of booster. */ #define C_API_PREDICT_LEAF_INDEX (2) /*!< \brief Predict leaf index. */ #define C_API_PREDICT_CONTRIB (3) /*!< \brief Predict feature contributions (SHAP values). */ +#define C_API_MATRIX_TYPE_CSR (0) /*!< \brief CSR sparse matrix type. */ +#define C_API_MATRIX_TYPE_CSC (1) /*!< \brief CSC sparse matrix type. */ + /*! * \brief Get string message of the last error. * \return Error information @@ -742,6 +745,62 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle, int64_t* out_len, double* out_result); +/*! + * \brief Make sparse prediction for a new dataset in CSR or CSC format. Currently only used for feature contributions. 
+ * \note + * The outputs are allocated by the library, as their sizes can vary for each invocation, but the shape should be the same: + * - for feature contributions, the shape of sparse matrix will be ``num_class * num_data * (num_feature + 1)``. + * The output indptr_type for the sparse matrix will be the same as the given input indptr_type. + * Call ``LGBM_BoosterFreePredictSparse`` to deallocate resources. + * \param handle Handle of booster + * \param indptr Pointer to row headers for CSR or column headers for CSC + * \param indptr_type Type of ``indptr``, can be ``C_API_DTYPE_INT32`` or ``C_API_DTYPE_INT64`` + * \param indices Pointer to column indices for CSR or row indices for CSC + * \param data Pointer to the data space + * \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64`` + * \param nindptr Number of entries in ``indptr``: number of rows + 1 for CSR or number of columns + 1 for CSC + * \param nelem Number of nonzero elements in the matrix + * \param num_col_or_row Number of columns for CSR or number of rows for CSC + * \param predict_type What should be predicted, only feature contributions supported currently + * - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values) + * \param num_iteration Number of iterations for prediction, <= 0 means no limit + * \param parameter Other parameters for prediction, e.g. 
early stopping for prediction + * \param matrix_type Type of matrix input and output, can be ``C_API_MATRIX_TYPE_CSR`` or ``C_API_MATRIX_TYPE_CSC`` + * \param[out] out_len Pointer to an array with two entries: length of output indices and data, then length of output indptr + * \param[out] out_indptr Pointer to output row headers for CSR or column headers for CSC + * \param[out] out_indices Pointer to sparse column indices for CSR or row indices for CSC + * \param[out] out_data Pointer to sparse data space + * \return 0 when succeed, -1 when failure happens + */ +LIGHTGBM_C_EXPORT int LGBM_BoosterPredictSparseOutput(BoosterHandle handle, + const void* indptr, + int indptr_type, + const int32_t* indices, + const void* data, + int data_type, + int64_t nindptr, + int64_t nelem, + int64_t num_col_or_row, + int predict_type, + int num_iteration, + const char* parameter, + int matrix_type, + int64_t* out_len, + void** out_indptr, + int32_t** out_indices, + void** out_data); + +/*! + * \brief Method corresponding to ``LGBM_BoosterPredictSparseOutput`` to free the allocated data. + * \param indptr Pointer to output row headers or column headers to be deallocated + * \param indices Pointer to sparse indices to be deallocated + * \param data Pointer to sparse data space to be deallocated + * \param indptr_type Type of ``indptr``, can be ``C_API_DTYPE_INT32`` or ``C_API_DTYPE_INT64`` + * \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64`` + * \return 0 when succeed, -1 when failure happens + */ +LIGHTGBM_C_EXPORT int LGBM_BoosterFreePredictSparse(void* indptr, int32_t* indices, void* data, int indptr_type, int data_type); + /*! * \brief Make prediction for a new dataset in CSR format. This method re-uses the internal predictor structure * from previous calls and is optimized for single row invocation. 
diff --git a/include/LightGBM/meta.h b/include/LightGBM/meta.h index b15b1ba4b378..3452f28d8ebc 100644 --- a/include/LightGBM/meta.h +++ b/include/LightGBM/meta.h @@ -5,10 +5,11 @@ #ifndef LIGHTGBM_META_H_ #define LIGHTGBM_META_H_ -#include #include #include +#include #include +#include #include #include @@ -58,6 +59,9 @@ typedef int32_t comm_size_t; using PredictFunction = std::function>&, double* output)>; +using PredictSparseFunction = +std::function>&, std::vector>* output)>; + typedef void(*ReduceFunction)(const char* input, char* output, int type_size, comm_size_t array_size); diff --git a/include/LightGBM/tree.h b/include/LightGBM/tree.h index 5ce3ff9b3eb1..b8e4800164f9 100644 --- a/include/LightGBM/tree.h +++ b/include/LightGBM/tree.h @@ -136,6 +136,8 @@ class Tree { inline int PredictLeafIndexByMap(const std::unordered_map& feature_values) const; inline void PredictContrib(const double* feature_values, int num_features, double* output); + inline void PredictContribByMap(const std::unordered_map& feature_values, + int num_features, std::unordered_map* output); /*! \brief Get Number of leaves*/ inline int num_leaves() const { return num_leaves_; } @@ -387,6 +389,12 @@ class Tree { PathElement *parent_unique_path, double parent_zero_fraction, double parent_one_fraction, int parent_feature_index) const; + void TreeSHAPByMap(const std::unordered_map& feature_values, + std::unordered_map* phi, + int node, int unique_depth, + PathElement *parent_unique_path, double parent_zero_fraction, + double parent_one_fraction, int parent_feature_index) const; + /*! 
\brief Extend our decision path with a fraction of one and zero extensions for TreeSHAP*/ static void ExtendPath(PathElement *unique_path, int unique_depth, double zero_fraction, double one_fraction, int feature_index); @@ -539,6 +547,18 @@ inline void Tree::PredictContrib(const double* feature_values, int num_features, } } +inline void Tree::PredictContribByMap(const std::unordered_map& feature_values, + int num_features, std::unordered_map* output) { + (*output)[num_features] += ExpectedValue(); + // Run the recursion with preallocated space for the unique path data + if (num_leaves_ > 1) { + CHECK_GE(max_depth_, 0); + const int max_path_len = max_depth_ + 1; + std::vector unique_path_data(max_path_len*(max_path_len + 1) / 2); + TreeSHAPByMap(feature_values, output, 0, 0, unique_path_data.data(), 1, 1, -1); + } +} + inline void Tree::RecomputeLeafDepths(int node, int depth) { if (node == 0) leaf_depth_.resize(num_leaves()); if (node < 0) { diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 01a5f31e51b6..9d5be6501b60 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -115,7 +115,15 @@ def cint32_array_to_numpy(cptr, length): if isinstance(cptr, ctypes.POINTER(ctypes.c_int32)): return np.fromiter(cptr, dtype=np.int32, count=length) else: - raise RuntimeError('Expected int pointer') + raise RuntimeError('Expected int32 pointer') + + +def cint64_array_to_numpy(cptr, length): + """Convert a ctypes int pointer array to a numpy array.""" + if isinstance(cptr, ctypes.POINTER(ctypes.c_int64)): + return np.fromiter(cptr, dtype=np.int64, count=length) + else: + raise RuntimeError('Expected int64 pointer') def c_str(string): @@ -272,6 +280,10 @@ def get(cls, *args): C_API_PREDICT_LEAF_INDEX = 2 C_API_PREDICT_CONTRIB = 3 +"""Macro definition of sparse matrix type""" +C_API_MATRIX_TYPE_CSR = 0 +C_API_MATRIX_TYPE_CSC = 1 + """Data type of data field""" FIELD_TYPE_MAPPER = {"label": C_API_DTYPE_FLOAT32, 
"weight": C_API_DTYPE_FLOAT32, @@ -525,8 +537,9 @@ def predict(self, data, num_iteration=-1, Returns ------- - result : numpy array + result : numpy array, scipy.sparse or list of scipy.sparse Prediction result. + Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``). """ if isinstance(data, Dataset): raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead") @@ -579,7 +592,8 @@ def predict(self, data, num_iteration=-1, preds, nrow = self.__pred_for_csr(csr, num_iteration, predict_type) if pred_leaf: preds = preds.astype(np.int32) - if is_reshape and preds.size != nrow: + is_sparse = scipy.sparse.issparse(preds) or isinstance(preds, list) + if is_reshape and not is_sparse and preds.size != nrow: if preds.size % nrow == 0: preds = preds.reshape(nrow, -1) else: @@ -651,6 +665,52 @@ def inner_predict(mat, num_iteration, predict_type, preds=None): else: return inner_predict(mat, num_iteration, predict_type) + def __create_sparse_native(self, cs, out_shape, out_ptr_indptr, out_ptr_indices, out_ptr_data, + indptr_type, data_type, is_csr=True): + # create numpy array from output arrays + data_indices_len = out_shape[0] + indptr_len = out_shape[1] + if indptr_type == C_API_DTYPE_INT32: + out_indptr = cint32_array_to_numpy(out_ptr_indptr, indptr_len) + elif indptr_type == C_API_DTYPE_INT64: + out_indptr = cint64_array_to_numpy(out_ptr_indptr, indptr_len) + else: + raise TypeError("Expected int32 or int64 type for indptr") + if data_type == C_API_DTYPE_FLOAT32: + out_data = cfloat32_array_to_numpy(out_ptr_data, data_indices_len) + elif data_type == C_API_DTYPE_FLOAT64: + out_data = cfloat64_array_to_numpy(out_ptr_data, data_indices_len) + else: + raise TypeError("Expected float32 or float64 type for data") + out_indices = cint32_array_to_numpy(out_ptr_indices, data_indices_len) + # break up indptr based on number of rows (note more than one matrix in 
multiclass case) + per_class_indptr_shape = cs.indptr.shape[0] + # for CSC there is extra column added + if not is_csr: + per_class_indptr_shape += 1 + out_indptr_arrays = np.split(out_indptr, out_indptr.shape[0] / per_class_indptr_shape) + # reformat output into a csr or csc matrix or list of csr or csc matrices + cs_output_matrices = [] + offset = 0 + for cs_indptr in out_indptr_arrays: + matrix_indptr_len = cs_indptr[cs_indptr.shape[0] - 1] + cs_indices = out_indices[offset + cs_indptr[0]:offset + matrix_indptr_len] + cs_data = out_data[offset + cs_indptr[0]:offset + matrix_indptr_len] + offset += matrix_indptr_len + # same shape as input csr or csc matrix except extra column for expected value + cs_shape = [cs.shape[0], cs.shape[1] + 1] + # note: make sure we copy data as it will be deallocated next + if is_csr: + cs_output_matrices.append(scipy.sparse.csr_matrix((cs_data, cs_indices, cs_indptr), cs_shape)) + else: + cs_output_matrices.append(scipy.sparse.csc_matrix((cs_data, cs_indices, cs_indptr), cs_shape)) + # free the temporary native indptr, indices, and data + _safe_call(_LIB.LGBM_BoosterFreePredictSparse(out_ptr_indptr, out_ptr_indices, out_ptr_data, + ctypes.c_int(indptr_type), ctypes.c_int(data_type))) + if len(cs_output_matrices) == 1: + return cs_output_matrices[0] + return cs_output_matrices + def __pred_for_csr(self, csr, num_iteration, predict_type): """Predict for a CSR data.""" def inner_predict(csr, num_iteration, predict_type, preds=None): @@ -666,13 +726,13 @@ def inner_predict(csr, num_iteration, predict_type, preds=None): ptr_data, type_ptr_data, _ = c_float_array(csr.data) assert csr.shape[1] <= MAX_INT32 - csr.indices = csr.indices.astype(np.int32, copy=False) + csr_indices = csr.indices.astype(np.int32, copy=False) _safe_call(_LIB.LGBM_BoosterPredictForCSR( self.handle, ptr_indptr, ctypes.c_int32(type_ptr_indptr), - csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), + csr_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), 
ptr_data, ctypes.c_int(type_ptr_data), ctypes.c_int64(len(csr.indptr)), @@ -687,6 +747,46 @@ def inner_predict(csr, num_iteration, predict_type, preds=None): raise ValueError("Wrong length for predict results") return preds, nrow + def inner_predict_sparse(csr, num_iteration, predict_type): + ptr_indptr, type_ptr_indptr, __ = c_int_array(csr.indptr) + ptr_data, type_ptr_data, _ = c_float_array(csr.data) + csr_indices = csr.indices.astype(np.int32, copy=False) + matrix_type = C_API_MATRIX_TYPE_CSR + if type_ptr_indptr == C_API_DTYPE_INT32: + out_ptr_indptr = ctypes.POINTER(ctypes.c_int32)() + else: + out_ptr_indptr = ctypes.POINTER(ctypes.c_int64)() + out_ptr_indices = ctypes.POINTER(ctypes.c_int32)() + if type_ptr_data == C_API_DTYPE_FLOAT32: + out_ptr_data = ctypes.POINTER(ctypes.c_float)() + else: + out_ptr_data = ctypes.POINTER(ctypes.c_double)() + out_shape = np.zeros(2, dtype=np.int64) + _safe_call(_LIB.LGBM_BoosterPredictSparseOutput( + self.handle, + ptr_indptr, + ctypes.c_int32(type_ptr_indptr), + csr_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), + ptr_data, + ctypes.c_int(type_ptr_data), + ctypes.c_int64(len(csr.indptr)), + ctypes.c_int64(len(csr.data)), + ctypes.c_int64(csr.shape[1]), + ctypes.c_int(predict_type), + ctypes.c_int(num_iteration), + c_str(self.pred_parameter), + ctypes.c_int(matrix_type), + out_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_int64)), + ctypes.byref(out_ptr_indptr), + ctypes.byref(out_ptr_indices), + ctypes.byref(out_ptr_data))) + matrices = self.__create_sparse_native(csr, out_shape, out_ptr_indptr, out_ptr_indices, out_ptr_data, + type_ptr_indptr, type_ptr_data, is_csr=True) + nrow = len(csr.indptr) - 1 + return matrices, nrow + + if predict_type == C_API_PREDICT_CONTRIB: + return inner_predict_sparse(csr, num_iteration, predict_type) nrow = len(csr.indptr) - 1 if nrow > MAX_INT32: sections = [0] + list(np.arange(start=MAX_INT32, stop=nrow, step=MAX_INT32)) + [nrow] @@ -704,9 +804,49 @@ def inner_predict(csr, 
num_iteration, predict_type, preds=None): def __pred_for_csc(self, csc, num_iteration, predict_type): """Predict for a CSC data.""" + def inner_predict_sparse(csc, num_iteration, predict_type): + ptr_indptr, type_ptr_indptr, __ = c_int_array(csc.indptr) + ptr_data, type_ptr_data, _ = c_float_array(csc.data) + csc_indices = csc.indices.astype(np.int32, copy=False) + matrix_type = C_API_MATRIX_TYPE_CSC + if type_ptr_indptr == C_API_DTYPE_INT32: + out_ptr_indptr = ctypes.POINTER(ctypes.c_int32)() + else: + out_ptr_indptr = ctypes.POINTER(ctypes.c_int64)() + out_ptr_indices = ctypes.POINTER(ctypes.c_int32)() + if type_ptr_data == C_API_DTYPE_FLOAT32: + out_ptr_data = ctypes.POINTER(ctypes.c_float)() + else: + out_ptr_data = ctypes.POINTER(ctypes.c_double)() + out_shape = np.zeros(2, dtype=np.int64) + _safe_call(_LIB.LGBM_BoosterPredictSparseOutput( + self.handle, + ptr_indptr, + ctypes.c_int32(type_ptr_indptr), + csc_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), + ptr_data, + ctypes.c_int(type_ptr_data), + ctypes.c_int64(len(csc.indptr)), + ctypes.c_int64(len(csc.data)), + ctypes.c_int64(csc.shape[0]), + ctypes.c_int(predict_type), + ctypes.c_int(num_iteration), + c_str(self.pred_parameter), + ctypes.c_int(matrix_type), + out_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_int64)), + ctypes.byref(out_ptr_indptr), + ctypes.byref(out_ptr_indices), + ctypes.byref(out_ptr_data))) + matrices = self.__create_sparse_native(csc, out_shape, out_ptr_indptr, out_ptr_indices, out_ptr_data, + type_ptr_indptr, type_ptr_data, is_csr=False) + nrow = csc.shape[0] + return matrices, nrow + nrow = csc.shape[0] if nrow > MAX_INT32: return self.__pred_for_csr(csc.tocsr(), num_iteration, predict_type) + if predict_type == C_API_PREDICT_CONTRIB: + return inner_predict_sparse(csc, num_iteration, predict_type) n_preds = self.__get_num_preds(num_iteration, nrow, predict_type) preds = np.zeros(n_preds, dtype=np.float64) out_num_preds = ctypes.c_int64(0) @@ -715,13 +855,13 @@ def 
__pred_for_csc(self, csc, num_iteration, predict_type): ptr_data, type_ptr_data, _ = c_float_array(csc.data) assert csc.shape[0] <= MAX_INT32 - csc.indices = csc.indices.astype(np.int32, copy=False) + csc_indices = csc.indices.astype(np.int32, copy=False) _safe_call(_LIB.LGBM_BoosterPredictForCSC( self.handle, ptr_indptr, ctypes.c_int32(type_ptr_indptr), - csc.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), + csc_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), ptr_data, ctypes.c_int(type_ptr_data), ctypes.c_int64(len(csc.indptr)), @@ -1074,12 +1214,12 @@ def __init_from_csr(self, csr, params_str, ref_dataset): ptr_data, type_ptr_data, _ = c_float_array(csr.data) assert csr.shape[1] <= MAX_INT32 - csr.indices = csr.indices.astype(np.int32, copy=False) + csr_indices = csr.indices.astype(np.int32, copy=False) _safe_call(_LIB.LGBM_DatasetCreateFromCSR( ptr_indptr, ctypes.c_int(type_ptr_indptr), - csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), + csr_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), ptr_data, ctypes.c_int(type_ptr_data), ctypes.c_int64(len(csr.indptr)), @@ -1100,12 +1240,12 @@ def __init_from_csc(self, csc, params_str, ref_dataset): ptr_data, type_ptr_data, _ = c_float_array(csc.data) assert csc.shape[0] <= MAX_INT32 - csc.indices = csc.indices.astype(np.int32, copy=False) + csc_indices = csc.indices.astype(np.int32, copy=False) _safe_call(_LIB.LGBM_DatasetCreateFromCSC( ptr_indptr, ctypes.c_int(type_ptr_indptr), - csc.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), + csc_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), ptr_data, ctypes.c_int(type_ptr_data), ctypes.c_int64(len(csc.indptr)), @@ -2677,8 +2817,9 @@ def predict(self, data, num_iteration=None, Returns ------- - result : numpy array + result : numpy array, scipy.sparse or list of scipy.sparse Prediction result. 
+ Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``). """ predictor = self._to_predictor(copy.deepcopy(kwargs)) if num_iteration is None: diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 24cd96471220..d0a61f526273 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -648,7 +648,7 @@ def predict(self, X, raw_score=False, num_iteration=None, The predicted values. X_leaves : array-like of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes] If ``pred_leaf=True``, the predicted leaf of every tree for each sample. - X_SHAP_values : array-like of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] + X_SHAP_values : array-like of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or list with n_classes length of such objects If ``pred_contrib=True``, the feature contributions for each sample. """ if self._n_features is None: @@ -873,7 +873,7 @@ def predict_proba(self, X, raw_score=False, num_iteration=None, The predicted probability for each class for each sample. X_leaves : array-like of shape = [n_samples, n_trees * n_classes] If ``pred_leaf=True``, the predicted leaf of every tree for each sample. - X_SHAP_values : array-like of shape = [n_samples, (n_features + 1) * n_classes] + X_SHAP_values : array-like of shape = [n_samples, (n_features + 1) * n_classes] or list with n_classes length of such objects If ``pred_contrib=True``, the feature contributions for each sample. 
""" result = super(LGBMClassifier, self).predict(X, raw_score, num_iteration, diff --git a/src/application/predictor.hpp b/src/application/predictor.hpp index 1c56cfa5eb2c..48ef227de2c6 100644 --- a/src/application/predictor.hpp +++ b/src/application/predictor.hpp @@ -88,12 +88,18 @@ class Predictor { double* output) { int tid = omp_get_thread_num(); CopyToPredictBuffer(predict_buf_[tid].data(), features); - // get result for leaf index - boosting_->PredictContrib(predict_buf_[tid].data(), output, - &early_stop_); + // get feature importances + boosting_->PredictContrib(predict_buf_[tid].data(), output); ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features); }; + predict_sparse_fun_ = [=](const std::vector>& features, + std::vector>* output) { + auto buf = CopyToPredictMap(features); + // get sparse feature importances + boosting_->PredictContribByMap(buf, output); + }; + } else { if (is_raw_score) { predict_fun_ = [=](const std::vector>& features, @@ -140,6 +146,11 @@ class Predictor { return predict_fun_; } + + inline const PredictSparseFunction& GetPredictSparseFunction() const { + return predict_sparse_fun_; + } + /*! * \brief predicting on data, then saving result to disk * \param data_filename Filename of data @@ -275,6 +286,7 @@ class Predictor { const Boosting* boosting_; /*! 
\brief function for prediction */ PredictFunction predict_fun_; + PredictSparseFunction predict_sparse_fun_; PredictionEarlyStopInstance early_stop_; int num_feature_; int num_pred_one_row_; diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index 7871bbfb086c..491e7966aff5 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -570,8 +570,7 @@ const double* GBDT::GetTrainingScore(int64_t* out_len) { return train_score_updater_->score(); } -void GBDT::PredictContrib(const double* features, double* output, const PredictionEarlyStopInstance* early_stop) const { - int early_stop_round_counter = 0; +void GBDT::PredictContrib(const double* features, double* output) const { // set zero const int num_features = max_feature_idx_ + 1; std::memset(output, 0, sizeof(double) * num_tree_per_iteration_ * (num_features + 1)); @@ -580,13 +579,16 @@ void GBDT::PredictContrib(const double* features, double* output, const Predicti for (int k = 0; k < num_tree_per_iteration_; ++k) { models_[i * num_tree_per_iteration_ + k]->PredictContrib(features, num_features, output + k*(num_features + 1)); } - // check early stopping - ++early_stop_round_counter; - if (early_stop->round_period == early_stop_round_counter) { - if (early_stop->callback_function(output, num_tree_per_iteration_)) { - return; - } - early_stop_round_counter = 0; + } +} + +void GBDT::PredictContribByMap(const std::unordered_map& features, + std::vector>* output) const { + const int num_features = max_feature_idx_ + 1; + for (int i = 0; i < num_iteration_for_pred_; ++i) { + // predict all the trees for one iteration + for (int k = 0; k < num_tree_per_iteration_; ++k) { + models_[i * num_tree_per_iteration_ + k]->PredictContribByMap(features, num_features, &((*output)[k])); } } } diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index f3ece67fec0b..0e1d014636f4 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -210,18 +210,18 @@ class GBDT : public GBDTBase { * \return number of 
prediction */ inline int NumPredictOneRow(int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const override { - int num_preb_in_one_row = num_class_; + int num_pred_in_one_row = num_class_; if (is_pred_leaf) { int max_iteration = GetCurrentIteration(); if (num_iteration > 0) { - num_preb_in_one_row *= static_cast(std::min(max_iteration, num_iteration)); + num_pred_in_one_row *= static_cast(std::min(max_iteration, num_iteration)); } else { - num_preb_in_one_row *= max_iteration; + num_pred_in_one_row *= max_iteration; } } else if (is_pred_contrib) { - num_preb_in_one_row = num_tree_per_iteration_ * (max_feature_idx_ + 2); // +1 for 0-based indexing, +1 for baseline + num_pred_in_one_row = num_tree_per_iteration_ * (max_feature_idx_ + 2); // +1 for 0-based indexing, +1 for baseline } - return num_preb_in_one_row; + return num_pred_in_one_row; } void PredictRaw(const double* features, double* output, @@ -240,8 +240,10 @@ class GBDT : public GBDTBase { void PredictLeafIndexByMap(const std::unordered_map& features, double* output) const override; - void PredictContrib(const double* features, double* output, - const PredictionEarlyStopInstance* earlyStop) const override; + void PredictContrib(const double* features, double* output) const override; + + void PredictContribByMap(const std::unordered_map& features, + std::vector>* output) const override; /*! 
* \brief Dump model to json format string diff --git a/src/c_api.cpp b/src/c_api.cpp index 290f219fa639..ec77673b28a8 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -382,16 +382,11 @@ class Booster { *out_len = single_row_predictor_[predict_type]->num_pred_in_one_row; } - - void Predict(int num_iteration, int predict_type, int nrow, int ncol, - std::function>(int row_idx)> get_row_fun, - const Config& config, - double* out_result, int64_t* out_len) { + Predictor CreatePredictor(int num_iteration, int predict_type, int ncol, const Config& config) { if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) { Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n" \ "You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1); } - std::lock_guard lock(mutex_); bool is_predict_leaf = false; bool is_raw_score = false; bool predict_contrib = false; @@ -407,6 +402,22 @@ class Booster { Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib, config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin); + return predictor; + } + + void Predict(int num_iteration, int predict_type, int nrow, int ncol, + std::function>(int row_idx)> get_row_fun, + const Config& config, + double* out_result, int64_t* out_len) { + std::lock_guard lock(mutex_); + auto predictor = CreatePredictor(num_iteration, predict_type, ncol, config); + bool is_predict_leaf = false; + bool predict_contrib = false; + if (predict_type == C_API_PREDICT_LEAF_INDEX) { + is_predict_leaf = true; + } else if (predict_type == C_API_PREDICT_CONTRIB) { + predict_contrib = true; + } int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib); auto pred_fun = predictor.GetPredictFunction(); OMP_INIT_EX(); @@ -422,6 +433,236 @@ class Booster { 
*out_len = num_pred_in_one_row * nrow; } + void PredictSparse(int num_iteration, int predict_type, int64_t nrow, int ncol, + std::function>(int64_t row_idx)> get_row_fun, + const Config& config, int64_t* out_elements_size, + std::vector>>* agg_ptr, + int32_t** out_indices, void** out_data, int data_type, + bool* is_data_float32_ptr, int num_matrices) { + auto predictor = CreatePredictor(num_iteration, predict_type, ncol, config); + auto pred_sparse_fun = predictor.GetPredictSparseFunction(); + std::vector>>& agg = *agg_ptr; + OMP_INIT_EX(); + #pragma omp parallel for schedule(static) + for (int64_t i = 0; i < nrow; ++i) { + OMP_LOOP_EX_BEGIN(); + auto one_row = get_row_fun(i); + agg[i] = std::vector>(num_matrices); + pred_sparse_fun(one_row, &agg[i]); + OMP_LOOP_EX_END(); + } + OMP_THROW_EX(); + // calculate the nonzero data and indices size + int64_t elements_size = 0; + for (int64_t i = 0; i < static_cast(agg.size()); ++i) { + auto row_vector = agg[i]; + for (int j = 0; j < static_cast(row_vector.size()); ++j) { + elements_size += static_cast(row_vector[j].size()); + } + } + *out_elements_size = elements_size; + *is_data_float32_ptr = false; + // allocate data and indices arrays + if (data_type == C_API_DTYPE_FLOAT32) { + *out_data = new float[elements_size]; + *is_data_float32_ptr = true; + } else if (data_type == C_API_DTYPE_FLOAT64) { + *out_data = new double[elements_size]; + } else { + Log::Fatal("Unknown data type in PredictSparse"); + return; + } + *out_indices = new int32_t[elements_size]; + } + + void PredictSparseCSR(int num_iteration, int predict_type, int64_t nrow, int ncol, + std::function>(int64_t row_idx)> get_row_fun, + const Config& config, + int64_t* out_len, void** out_indptr, int indptr_type, + int32_t** out_indices, void** out_data, int data_type) { + std::lock_guard lock(mutex_); + // Get the number of trees per iteration (for multiclass scenario we output multiple sparse matrices) + int num_matrices = boosting_->NumModelPerIteration(); + 
bool is_indptr_int32 = false; + bool is_data_float32 = false; + int64_t indptr_size = (nrow + 1) * num_matrices; + if (indptr_type == C_API_DTYPE_INT32) { + *out_indptr = new int32_t[indptr_size]; + is_indptr_int32 = true; + } else if (indptr_type == C_API_DTYPE_INT64) { + *out_indptr = new int64_t[indptr_size]; + } else { + Log::Fatal("Unknown indptr type in PredictSparseCSR"); + return; + } + // aggregated per row feature contribution results + std::vector>> agg(nrow); + int64_t elements_size = 0; + PredictSparse(num_iteration, predict_type, nrow, ncol, get_row_fun, config, &elements_size, &agg, + out_indices, out_data, data_type, &is_data_float32, num_matrices); + std::vector row_sizes(num_matrices * nrow); + std::vector row_matrix_offsets(num_matrices * nrow); + int64_t row_vector_cnt = 0; + for (int m = 0; m < num_matrices; ++m) { + for (int64_t i = 0; i < static_cast(agg.size()); ++i) { + auto row_vector = agg[i]; + auto row_vector_size = row_vector[m].size(); + // keep track of the row_vector sizes for parallelization + row_sizes[row_vector_cnt] = static_cast(row_vector_size); + if (i == 0) { + row_matrix_offsets[row_vector_cnt] = 0; + } else { + row_matrix_offsets[row_vector_cnt] = static_cast(row_sizes[row_vector_cnt - 1] + row_matrix_offsets[row_vector_cnt - 1]); + } + row_vector_cnt++; + } + } + // copy vector results to output for each row + int64_t indptr_index = 0; + for (int m = 0; m < num_matrices; ++m) { + if (is_indptr_int32) { + (reinterpret_cast(*out_indptr))[indptr_index] = 0; + } else { + (reinterpret_cast(*out_indptr))[indptr_index] = 0; + } + indptr_index++; + int64_t matrix_start_index = m * static_cast(agg.size()); + OMP_INIT_EX(); + #pragma omp parallel for schedule(static) + for (int64_t i = 0; i < static_cast(agg.size()); ++i) { + OMP_LOOP_EX_BEGIN(); + auto row_vector = agg[i]; + int64_t row_start_index = matrix_start_index + i; + int64_t element_index = row_matrix_offsets[row_start_index]; + int64_t indptr_loop_index = indptr_index + 
i; + for (auto it = row_vector[m].begin(); it != row_vector[m].end(); ++it) { + (*out_indices)[element_index] = it->first; + if (is_data_float32) { + (reinterpret_cast(*out_data))[element_index] = static_cast(it->second); + } else { + (reinterpret_cast(*out_data))[element_index] = it->second; + } + element_index++; + } + int64_t indptr_value = row_matrix_offsets[row_start_index] + row_sizes[row_start_index]; + if (is_indptr_int32) { + (reinterpret_cast(*out_indptr))[indptr_loop_index] = static_cast(indptr_value); + } else { + (reinterpret_cast(*out_indptr))[indptr_loop_index] = indptr_value; + } + OMP_LOOP_EX_END(); + } + OMP_THROW_EX(); + indptr_index += static_cast(agg.size()); + } + out_len[0] = elements_size; + out_len[1] = indptr_size; + } + + void PredictSparseCSC(int num_iteration, int predict_type, int64_t nrow, int ncol, + std::function>(int64_t row_idx)> get_row_fun, + const Config& config, + int64_t* out_len, void** out_col_ptr, int col_ptr_type, + int32_t** out_indices, void** out_data, int data_type) { + std::lock_guard lock(mutex_); + // Get the number of trees per iteration (for multiclass scenario we output multiple sparse matrices) + int num_matrices = boosting_->NumModelPerIteration(); + auto predictor = CreatePredictor(num_iteration, predict_type, ncol, config); + auto pred_sparse_fun = predictor.GetPredictSparseFunction(); + bool is_col_ptr_int32 = false; + bool is_data_float32 = false; + int num_output_cols = ncol + 1; + int col_ptr_size = (num_output_cols + 1) * num_matrices; + if (col_ptr_type == C_API_DTYPE_INT32) { + *out_col_ptr = new int32_t[col_ptr_size]; + is_col_ptr_int32 = true; + } else if (col_ptr_type == C_API_DTYPE_INT64) { + *out_col_ptr = new int64_t[col_ptr_size]; + } else { + Log::Fatal("Unknown col_ptr type in PredictSparseCSC"); + return; + } + // aggregated per row feature contribution results + std::vector>> agg(nrow); + int64_t elements_size = 0; + PredictSparse(num_iteration, predict_type, nrow, ncol, get_row_fun, 
config, &elements_size, &agg, + out_indices, out_data, data_type, &is_data_float32, num_matrices); + // calculate number of elements per column to construct + // the CSC matrix with random access + std::vector> column_sizes(num_matrices); + for (int m = 0; m < num_matrices; ++m) { + column_sizes[m] = std::vector(num_output_cols, 0); + for (int64_t i = 0; i < static_cast(agg.size()); ++i) { + auto row_vector = agg[i]; + for (auto it = row_vector[m].begin(); it != row_vector[m].end(); ++it) { + column_sizes[m][it->first] += 1; + } + } + } + // keep track of column counts + std::vector> column_counts(num_matrices); + // keep track of beginning index for each column + std::vector> column_start_indices(num_matrices); + // keep track of beginning index for each matrix + std::vector matrix_start_indices(num_matrices, 0); + int col_ptr_index = 0; + for (int m = 0; m < num_matrices; ++m) { + int64_t col_ptr_value = 0; + column_start_indices[m] = std::vector(num_output_cols, 0); + column_counts[m] = std::vector(num_output_cols, 0); + if (is_col_ptr_int32) { + (reinterpret_cast(*out_col_ptr))[col_ptr_index] = static_cast(col_ptr_value); + } else { + (reinterpret_cast(*out_col_ptr))[col_ptr_index] = col_ptr_value; + } + col_ptr_index++; + for (int64_t i = 1; i < static_cast(column_sizes[m].size()); ++i) { + column_start_indices[m][i] = column_sizes[m][i - 1] + column_start_indices[m][i - 1]; + if (is_col_ptr_int32) { + (reinterpret_cast(*out_col_ptr))[col_ptr_index] = static_cast(column_start_indices[m][i]); + } else { + (reinterpret_cast(*out_col_ptr))[col_ptr_index] = column_start_indices[m][i]; + } + col_ptr_index++; + } + int64_t last_elem_index = static_cast(column_sizes[m].size()) - 1; + int64_t last_column_start_index = column_start_indices[m][last_elem_index]; + int64_t last_column_size = column_sizes[m][last_elem_index]; + if (is_col_ptr_int32) { + (reinterpret_cast(*out_col_ptr))[col_ptr_index] = static_cast(last_column_start_index + last_column_size); + } else { + 
(reinterpret_cast(*out_col_ptr))[col_ptr_index] = last_column_start_index + last_column_size; + } + if (m != 0) { + matrix_start_indices[m] = matrix_start_indices[m - 1] + + last_column_start_index + + last_column_size; + } + } + for (int m = 0; m < num_matrices; ++m) { + for (int64_t i = 0; i < static_cast(agg.size()); ++i) { + auto row_vector = agg[i]; + for (auto it = row_vector[m].begin(); it != row_vector[m].end(); ++it) { + int64_t col_idx = it->first; + int64_t element_index = column_start_indices[m][col_idx] + + matrix_start_indices[m] + + column_counts[m][col_idx]; + // store the row index + (*out_indices)[element_index] = static_cast(i); + // update column count + column_counts[m][col_idx]++; + if (is_data_float32) { + (reinterpret_cast(*out_data))[element_index] = static_cast(it->second); + } else { + (reinterpret_cast(*out_data))[element_index] = it->second; + } + } + } + } + out_len[0] = elements_size; + out_len[1] = col_ptr_size; + } + void Predict(int num_iteration, int predict_type, const char* data_filename, int data_has_header, const Config& config, const char* result_filename) { @@ -581,7 +822,8 @@ RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int d std::function>(int row_idx)> RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type); -std::function>(int idx)> +template +std::function>(T idx)> RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem); @@ -713,7 +955,7 @@ int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset, int64_t start_row) { API_BEGIN(); auto p_dataset = reinterpret_cast(dataset); - auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); + auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); int32_t nrow = static_cast(nindptr - 1); OMP_INIT_EX(); #pragma omp parallel for schedule(static) @@ -860,7 
+1102,7 @@ int LGBM_DatasetCreateFromCSR(const void* indptr, omp_set_num_threads(config.num_threads); } std::unique_ptr ret; - auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); + auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); int32_t nrow = static_cast(nindptr - 1); if (reference == nullptr) { // sample data first @@ -1520,13 +1762,98 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, omp_set_num_threads(config.num_threads); } Booster* ref_booster = reinterpret_cast(handle); - auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); + auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); int nrow = static_cast(nindptr - 1); ref_booster->Predict(num_iteration, predict_type, nrow, static_cast(num_col), get_row_fun, config, out_result, out_len); API_END(); } +int LGBM_BoosterPredictSparseOutput(BoosterHandle handle, + const void* indptr, + int indptr_type, + const int32_t* indices, + const void* data, + int data_type, + int64_t nindptr, + int64_t nelem, + int64_t num_col_or_row, + int predict_type, + int num_iteration, + const char* parameter, + int matrix_type, + int64_t* out_len, + void** out_indptr, + int32_t** out_indices, + void** out_data) { + API_BEGIN(); + Booster* ref_booster = reinterpret_cast(handle); + auto param = Config::Str2Map(parameter); + Config config; + config.Set(param); + if (config.num_threads > 0) { + omp_set_num_threads(config.num_threads); + } + if (matrix_type == C_API_MATRIX_TYPE_CSR) { + if (num_col_or_row <= 0) { + Log::Fatal("The number of columns should be greater than zero."); + } else if (num_col_or_row >= INT32_MAX) { + Log::Fatal("The number of columns should be smaller than INT32_MAX."); + } + auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); + int64_t nrow = nindptr - 1; + 
ref_booster->PredictSparseCSR(num_iteration, predict_type, nrow, static_cast(num_col_or_row), get_row_fun, + config, out_len, out_indptr, indptr_type, out_indices, out_data, data_type); + } else if (matrix_type == C_API_MATRIX_TYPE_CSC) { + int num_threads = OMP_NUM_THREADS(); + int ncol = static_cast(nindptr - 1); + std::vector> iterators(num_threads, std::vector()); + for (int i = 0; i < num_threads; ++i) { + for (int j = 0; j < ncol; ++j) { + iterators[i].emplace_back(indptr, indptr_type, indices, data, data_type, nindptr, nelem, j); + } + } + std::function>(int64_t row_idx)> get_row_fun = + [&iterators, ncol](int64_t i) { + std::vector> one_row; + one_row.reserve(ncol); + const int tid = omp_get_thread_num(); + for (int j = 0; j < ncol; ++j) { + auto val = iterators[tid][j].Get(static_cast(i)); + if (std::fabs(val) > kZeroThreshold || std::isnan(val)) { + one_row.emplace_back(j, val); + } + } + return one_row; + }; + ref_booster->PredictSparseCSC(num_iteration, predict_type, num_col_or_row, ncol, get_row_fun, config, + out_len, out_indptr, indptr_type, out_indices, out_data, data_type); + } else { + Log::Fatal("Unknown matrix type in LGBM_BoosterPredictSparseOutput"); + } + API_END(); +} + +int LGBM_BoosterFreePredictSparse(void* indptr, int32_t* indices, void* data, int indptr_type, int data_type) { + API_BEGIN(); + if (indptr_type == C_API_DTYPE_INT32) { + delete reinterpret_cast(indptr); + } else if (indptr_type == C_API_DTYPE_INT64) { + delete reinterpret_cast(indptr); + } else { + Log::Fatal("Unknown indptr type in LGBM_BoosterFreePredictSparse"); + } + delete indices; + if (data_type == C_API_DTYPE_FLOAT32) { + delete reinterpret_cast(data); + } else if (data_type == C_API_DTYPE_FLOAT64) { + delete reinterpret_cast(data); + } else { + Log::Fatal("Unknown data type in LGBM_BoosterFreePredictSparse"); + } + API_END(); +} + int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, const void* indptr, int indptr_type, @@ -1554,7 +1881,7 @@ int 
LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, omp_set_num_threads(config.num_threads); } Booster* ref_booster = reinterpret_cast(handle); - auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); + auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); ref_booster->PredictSingleRow(num_iteration, predict_type, static_cast(num_col), get_row_fun, config, out_result, out_len); API_END(); } @@ -1890,13 +2217,14 @@ RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type) { }; } -std::function>(int idx)> +template +std::function>(T idx)> RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) { if (data_type == C_API_DTYPE_FLOAT32) { const float* data_ptr = reinterpret_cast(data); if (indptr_type == C_API_DTYPE_INT32) { const int32_t* ptr_indptr = reinterpret_cast(indptr); - return [=] (int idx) { + return [=] (T idx) { std::vector> ret; int64_t start = ptr_indptr[idx]; int64_t end = ptr_indptr[idx + 1]; @@ -1910,7 +2238,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, }; } else if (indptr_type == C_API_DTYPE_INT64) { const int64_t* ptr_indptr = reinterpret_cast(indptr); - return [=] (int idx) { + return [=] (T idx) { std::vector> ret; int64_t start = ptr_indptr[idx]; int64_t end = ptr_indptr[idx + 1]; @@ -1927,7 +2255,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const double* data_ptr = reinterpret_cast(data); if (indptr_type == C_API_DTYPE_INT32) { const int32_t* ptr_indptr = reinterpret_cast(indptr); - return [=] (int idx) { + return [=] (T idx) { std::vector> ret; int64_t start = ptr_indptr[idx]; int64_t end = ptr_indptr[idx + 1]; @@ -1941,7 +2269,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, }; } else if (indptr_type == C_API_DTYPE_INT64) { const int64_t* 
ptr_indptr = reinterpret_cast(indptr); - return [=] (int idx) { + return [=] (T idx) { std::vector> ret; int64_t start = ptr_indptr[idx]; int64_t end = ptr_indptr[idx + 1]; diff --git a/src/io/tree.cpp b/src/io/tree.cpp index 63641311787f..8e5104f168eb 100644 --- a/src/io/tree.cpp +++ b/src/io/tree.cpp @@ -687,7 +687,9 @@ void Tree::TreeSHAP(const double *feature_values, double *phi, double parent_one_fraction, int parent_feature_index) const { // extend the unique path PathElement* unique_path = parent_unique_path + unique_depth; - if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path); + if (unique_depth > 0) { + std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path); + } ExtendPath(unique_path, unique_depth, parent_zero_fraction, parent_one_fraction, parent_feature_index); @@ -730,6 +732,58 @@ void Tree::TreeSHAP(const double *feature_values, double *phi, } } +// recursive sparse computation of SHAP values for a decision tree +void Tree::TreeSHAPByMap(const std::unordered_map& feature_values, std::unordered_map* phi, + int node, int unique_depth, + PathElement *parent_unique_path, double parent_zero_fraction, + double parent_one_fraction, int parent_feature_index) const { + // extend the unique path + PathElement* unique_path = parent_unique_path + unique_depth; + if (unique_depth > 0) { + std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path); + } + ExtendPath(unique_path, unique_depth, parent_zero_fraction, + parent_one_fraction, parent_feature_index); + + // leaf node + if (node < 0) { + for (int i = 1; i <= unique_depth; ++i) { + const double w = UnwoundPathSum(unique_path, unique_depth, i); + const PathElement &el = unique_path[i]; + (*phi)[el.feature_index] += w*(el.one_fraction - el.zero_fraction)*leaf_value_[~node]; + } + + // internal node + } else { + const int hot_index = Decision(feature_values.count(split_feature_[node]) > 0 ? 
feature_values.at(split_feature_[node]) : 0.0f, node); + const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]); + const double w = data_count(node); + const double hot_zero_fraction = data_count(hot_index) / w; + const double cold_zero_fraction = data_count(cold_index) / w; + double incoming_zero_fraction = 1; + double incoming_one_fraction = 1; + + // see if we have already split on this feature, + // if so we undo that split so we can redo it for this node + int path_index = 0; + for (; path_index <= unique_depth; ++path_index) { + if (unique_path[path_index].feature_index == split_feature_[node]) break; + } + if (path_index != unique_depth + 1) { + incoming_zero_fraction = unique_path[path_index].zero_fraction; + incoming_one_fraction = unique_path[path_index].one_fraction; + UnwindPath(unique_path, unique_depth, path_index); + unique_depth -= 1; + } + + TreeSHAPByMap(feature_values, phi, hot_index, unique_depth + 1, unique_path, + hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_feature_[node]); + + TreeSHAPByMap(feature_values, phi, cold_index, unique_depth + 1, unique_path, + cold_zero_fraction*incoming_zero_fraction, 0, split_feature_[node]); + } +} + double Tree::ExpectedValue() const { if (num_leaves_ == 1) return LeafOutput(0); const double total_count = internal_count_[0]; diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index b5cefb1e117f..a26bd7a09449 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -9,9 +9,9 @@ import lightgbm as lgb import numpy as np -from scipy.sparse import csr_matrix +from scipy.sparse import csr_matrix, isspmatrix_csr, isspmatrix_csc from sklearn.datasets import (load_boston, load_breast_cancer, load_digits, - load_iris, load_svmlight_file) + load_iris, load_svmlight_file, make_multilabel_classification) from sklearn.metrics import log_loss, mean_absolute_error, 
mean_squared_error, roc_auc_score from sklearn.model_selection import train_test_split, TimeSeriesSplit, GroupKFold @@ -941,6 +941,60 @@ def test_contribs(self): self.assertLess(np.linalg.norm(gbm.predict(X_test, raw_score=True) - np.sum(gbm.predict(X_test, pred_contrib=True), axis=1)), 1e-4) + def test_contribs_sparse(self): + n_features = 20 + n_samples = 100 + # generate CSR sparse dataset + X, y = make_multilabel_classification(n_samples=n_samples, + sparse=True, + n_features=n_features, + n_classes=1, + n_labels=2) + y = y.flatten() + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) + params = { + 'objective': 'binary', + 'verbose': -1, + } + lgb_train = lgb.Dataset(X_train, y_train) + gbm = lgb.train(params, lgb_train, num_boost_round=20) + contribs_csr = gbm.predict(X_test, pred_contrib=True) + self.assertTrue(isspmatrix_csr(contribs_csr)) + # convert data to dense and get back same contribs + contribs_dense = gbm.predict(X_test.toarray(), pred_contrib=True) + # validate the values are the same + np.testing.assert_allclose(contribs_csr.toarray(), contribs_dense) + self.assertLess(np.linalg.norm(gbm.predict(X_test, raw_score=True) + - np.sum(contribs_dense, axis=1)), 1e-4) + # validate using CSC matrix + X_test_csc = X_test.tocsc() + contribs_csc = gbm.predict(X_test_csc, pred_contrib=True) + self.assertTrue(isspmatrix_csc(contribs_csc)) + # validate the values are the same + np.testing.assert_allclose(contribs_csc.toarray(), contribs_dense) + + @unittest.skipIf(psutil.virtual_memory().available / 1024 / 1024 / 1024 < 3, 'not enough RAM') + def test_int32_max_sparse_contribs(self): + params = { + 'objective': 'binary' + } + train_features = np.random.rand(100, 1000) + train_targets = [0] * 50 + [1] * 50 + lgb_train = lgb.Dataset(train_features, train_targets) + gbm = lgb.train(params, lgb_train, num_boost_round=2) + csr_input_shape = (3000000, 1000) + test_features = csr_matrix(csr_input_shape) + for i in range(0, 
csr_input_shape[0], csr_input_shape[0] // 6): + for j in range(0, 1000, 100): + test_features[i, j] = random.random() + y_pred_csr = gbm.predict(test_features, pred_contrib=True) + # Note there is an extra column added to the output for the expected value + csr_output_shape = (csr_input_shape[0], csr_input_shape[1] + 1) + self.assertTupleEqual(y_pred_csr.shape, csr_output_shape) + y_pred_csc = gbm.predict(test_features.tocsc(), pred_contrib=True) + # Note output CSC shape should be same as CSR output shape + self.assertTupleEqual(y_pred_csc.shape, csr_output_shape) + def test_sliced_data(self): def train_and_get_predictions(features, labels): dataset = lgb.Dataset(features, label=labels)