diff --git a/include/treelite/c_api.h b/include/treelite/c_api.h index f2dd46c6..53a8ff5f 100644 --- a/include/treelite/c_api.h +++ b/include/treelite/c_api.h @@ -181,6 +181,8 @@ TREELITE_DLL int TreeliteLoadXGBoostModelFromMemoryBuffer(const void* buf, * if node k is a leaf node. * \param n_node_samples n_node_samples[i][k] stores the number of data samples associated with * node k of the i-th tree. + * \param weighted_n_node_samples weighted_n_node_samples[i][k] stores the sum of weighted data + * samples associated with node k of the i-th tree. * \param impurity impurity[i][k] stores the impurity measure (gini, entropy etc) associated with * node k of the i-th tree. * \param out pointer to store the loaded model @@ -189,8 +191,8 @@ TREELITE_DLL int TreeliteLoadXGBoostModelFromMemoryBuffer(const void* buf, TREELITE_DLL int TreeliteLoadSKLearnRandomForestRegressor( int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, - const double** value, const int64_t** n_node_samples, const double** impurity, - ModelHandle* out); + const double** value, const int64_t** n_node_samples, const double** weighted_n_node_samples, + const double** impurity, ModelHandle* out); /*! * \brief Load a scikit-learn isolation forest model from a collection of arrays. Refer to @@ -211,6 +213,8 @@ TREELITE_DLL int TreeliteLoadSKLearnRandomForestRegressor( * only defined if node k is a leaf node. * \param n_node_samples n_node_samples[i][k] stores the number of data samples associated with * node k of the i-th tree. + * \param weighted_n_node_samples weighted_n_node_samples[i][k] stores the sum of weighted data + * samples associated with node k of the i-th tree. * \param impurity not used, but must be passed as array of arrays for each tree and node. * \param ratio_c standardizing constant to use for calculation of the anomaly score. * \param out pointer to store the loaded model @@ -219,8 +223,8 @@ TREELITE_DLL int TreeliteLoadSKLearnRandomForestRegressor( TREELITE_DLL int TreeliteLoadSKLearnIsolationForest( int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, - const double** value, const int64_t** n_node_samples, const double** impurity, - const double ratio_c, ModelHandle* out); + const double** value, const int64_t** n_node_samples, const double** weighted_n_node_samples, + const double** impurity, const double ratio_c, ModelHandle* out); /*! * \brief Load a scikit-learn random forest classifier model from a collection of arrays. Refer to @@ -243,6 +247,8 @@ TREELITE_DLL int TreeliteLoadSKLearnIsolationForest( * if node k is a leaf node. * \param n_node_samples n_node_samples[i][k] stores the number of data samples associated with * node k of the i-th tree. + * \param weighted_n_node_samples weighted_n_node_samples[i][k] stores the sum of weighted data + * samples associated with node k of the i-th tree. * \param impurity impurity[i][k] stores the impurity measure (gini, entropy etc) associated with * node k of the i-th tree. * \param out pointer to store the loaded model @@ -252,7 +258,7 @@ TREELITE_DLL int TreeliteLoadSKLearnRandomForestClassifier( int n_estimators, int n_features, int n_classes, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, - const double** impurity, ModelHandle* out); + const double** weighted_n_node_samples, const double** impurity, ModelHandle* out); /*! * \brief Load a scikit-learn gradient boosting regressor model from a collection of arrays. Refer @@ -273,6 +279,8 @@ TREELITE_DLL int TreeliteLoadSKLearnRandomForestClassifier( * if node k is a leaf node. * \param n_node_samples n_node_samples[i][k] stores the number of data samples associated with * node k of the i-th tree. + * \param weighted_n_node_samples weighted_n_node_samples[i][k] stores the sum of weighted data + * samples associated with node k of the i-th tree. * \param impurity impurity[i][k] stores the impurity measure (gini, entropy etc) associated with * node k of the i-th tree. * \param out pointer to store the loaded model @@ -281,8 +289,8 @@ TREELITE_DLL int TreeliteLoadSKLearnRandomForestClassifier( TREELITE_DLL int TreeliteLoadSKLearnGradientBoostingRegressor( int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, - const double** value, const int64_t** n_node_samples, const double** impurity, - ModelHandle* out); + const double** value, const int64_t** n_node_samples, const double** weighted_n_node_samples, + const double** impurity, ModelHandle* out); /*! * \brief Load a scikit-learn gradient boosting classifier model from a collection of arrays. Refer @@ -304,6 +312,8 @@ TREELITE_DLL int TreeliteLoadSKLearnGradientBoostingRegressor( * if node k is a leaf node. * \param n_node_samples n_node_samples[i][k] stores the number of data samples associated with * node k of the i-th tree. + * \param weighted_n_node_samples weighted_n_node_samples[i][k] stores the sum of weighted data + * samples associated with node k of the i-th tree. * \param impurity impurity[i][k] stores the impurity measure (gini, entropy etc) associated with * node k of the i-th tree. * \param out pointer to store the loaded model @@ -313,7 +323,7 @@ TREELITE_DLL int TreeliteLoadSKLearnGradientBoostingClassifier( int n_estimators, int n_features, int n_classes, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, - const double** impurity, ModelHandle* out); + const double** weighted_n_node_samples, const double** impurity, ModelHandle* out); /*! * \brief Query the number of trees in the model diff --git a/include/treelite/frontend.h b/include/treelite/frontend.h index 12caf5a9..235dcf45 100644 --- a/include/treelite/frontend.h +++ b/include/treelite/frontend.h @@ -77,6 +77,8 @@ std::unique_ptr LoadXGBoostJSONModelString(const char* json_str * if node k is a leaf node. * \param n_node_samples n_node_samples[i][k] stores the number of data samples associated with * node k of the i-th tree. + * \param weighted_n_node_samples weighted_n_node_samples[i][k] stores the sum of weighted data + * samples associated with node k of the i-th tree. * \param impurity impurity[i][k] stores the impurity measure (gini, entropy etc) associated with * node k of the i-th tree. * \return loaded model @@ -84,7 +86,8 @@ std::unique_ptr LoadXGBoostJSONModelString(const char* json_str std::unique_ptr LoadSKLearnRandomForestRegressor( int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, - const double** value, const int64_t** n_node_samples, const double** impurity); + const double** value, const int64_t** n_node_samples, const double** weighted_n_node_samples, + const double** impurity); /*! * \brief Load a scikit-learn isolation forest model from a collection of arrays. Refer to * https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to @@ -104,6 +107,8 @@ std::unique_ptr LoadSKLearnRandomForestRegressor( * only defined if node k is a leaf node. * \param n_node_samples n_node_samples[i][k] stores the number of data samples associated with * node k of the i-th tree. + * \param weighted_n_node_samples weighted_n_node_samples[i][k] stores the sum of weighted data + * samples associated with node k of the i-th tree. * \param impurity not used, but must be passed as array of arrays for each tree and node. * \param ratio_c standardizing constant to use for calculation of the anomaly score. * \return loaded model @@ -111,8 +116,8 @@ std::unique_ptr LoadSKLearnRandomForestRegressor( std::unique_ptr LoadSKLearnIsolationForest( int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, - const double** value, const int64_t** n_node_samples, const double** impurity, - const double ratio_c); + const double** value, const int64_t** n_node_samples, const double** weighted_n_node_samples, + const double** impurity, const double ratio_c); /*! * \brief Load a scikit-learn random forest classifier model from a collection of arrays. Refer to * https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to @@ -134,6 +139,8 @@ std::unique_ptr LoadSKLearnIsolationForest( * if node k is a leaf node. * \param n_node_samples n_node_samples[i][k] stores the number of data samples associated with * node k of the i-th tree. + * \param weighted_n_node_samples weighted_n_node_samples[i][k] stores the sum of weighted data + * samples associated with node k of the i-th tree. * \param impurity impurity[i][k] stores the impurity measure (gini, entropy etc) associated with * node k of the i-th tree. * \return loaded model @@ -142,7 +149,7 @@ std::unique_ptr LoadSKLearnRandomForestClassifier( int n_estimators, int n_features, int n_classes, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, - const double** impurity); + const double** weighted_n_node_samples, const double** impurity); /*! * \brief Load a scikit-learn gradient boosting regressor model from a collection of arrays. Refer * to https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to @@ -162,6 +169,8 @@ std::unique_ptr LoadSKLearnRandomForestClassifier( * if node k is a leaf node. * \param n_node_samples n_node_samples[i][k] stores the number of data samples associated with * node k of the i-th tree. + * \param weighted_n_node_samples weighted_n_node_samples[i][k] stores the sum of weighted data + * samples associated with node k of the i-th tree. * \param impurity impurity[i][k] stores the impurity measure (gini, entropy etc) associated with * node k of the i-th tree. * \return loaded model @@ -169,7 +178,8 @@ std::unique_ptr LoadSKLearnRandomForestClassifier( std::unique_ptr LoadSKLearnGradientBoostingRegressor( int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, - const double** value, const int64_t** n_node_samples, const double** impurity); + const double** value, const int64_t** n_node_samples, const double** weighted_n_node_samples, + const double** impurity); /*! * \brief Load a scikit-learn gradient boosting classifier model from a collection of arrays. Refer * to https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to @@ -190,6 +200,8 @@ std::unique_ptr LoadSKLearnGradientBoostingRegressor( * if node k is a leaf node. * \param n_node_samples n_node_samples[i][k] stores the number of data samples associated with * node k of the i-th tree. + * \param weighted_n_node_samples weighted_n_node_samples[i][k] stores the sum of weighted data + * samples associated with node k of the i-th tree. * \param impurity impurity[i][k] stores the impurity measure (gini, entropy etc) associated with * node k of the i-th tree. * \return loaded model @@ -198,7 +210,7 @@ std::unique_ptr LoadSKLearnGradientBoostingClassifier( int n_estimators, int n_features, int n_classes, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, - const double** impurity); + const double** weighted_n_node_samples, const double** impurity); //-------------------------------------------------------------------------- // model builder interface: build trees incrementally diff --git a/python/treelite/sklearn/importer.py b/python/treelite/sklearn/importer.py index ff4b3012..05afd6ee 100644 --- a/python/treelite/sklearn/importer.py +++ b/python/treelite/sklearn/importer.py @@ -141,6 +141,7 @@ def import_model(sklearn_model): threshold = ArrayOfArrays(dtype=np.float64) value = ArrayOfArrays(dtype=np.float64) n_node_samples = ArrayOfArrays(dtype=np.int64) + weighted_n_node_samples = ArrayOfArrays(dtype=np.float64) impurity = ArrayOfArrays(dtype=np.float64) for estimator in sklearn_model.estimators_: if isinstance(sklearn_model, (GradientBoostingR, GradientBoostingC)): @@ -170,6 +171,8 @@ def import_model(sklearn_model): value.add(isolation_depths.reshape((-1,1,1)), expected_shape=leaf_value_expected_shape(tree.node_count)) n_node_samples.add(tree.n_node_samples, expected_shape=(tree.node_count,)) + weighted_n_node_samples.add(tree.weighted_n_node_samples, + expected_shape=(tree.node_count,)) impurity.add(tree.impurity, expected_shape=(tree.node_count,)) handle = ctypes.c_void_p() @@ -178,36 +181,36 @@ def import_model(sklearn_model): ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), c_array(ctypes.c_int64, node_count), children_left.as_c_array(), children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(), - value.as_c_array(), n_node_samples.as_c_array(), impurity.as_c_array(), - ctypes.byref(handle))) + value.as_c_array(), n_node_samples.as_c_array(), weighted_n_node_samples.as_c_array(), + impurity.as_c_array(), ctypes.byref(handle))) elif isinstance(sklearn_model, IsolationForest): _check_call(_LIB.TreeliteLoadSKLearnIsolationForest( ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), c_array(ctypes.c_int64, node_count), children_left.as_c_array(), children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(), - value.as_c_array(), n_node_samples.as_c_array(), impurity.as_c_array(), - ctypes.c_double(ratio_c), ctypes.byref(handle))) + value.as_c_array(), n_node_samples.as_c_array(), weighted_n_node_samples.as_c_array(), + impurity.as_c_array(), ctypes.c_double(ratio_c), ctypes.byref(handle))) elif isinstance(sklearn_model, (RandomForestC, ExtraTreesC)): _check_call(_LIB.TreeliteLoadSKLearnRandomForestClassifier( ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), ctypes.c_int(sklearn_model.n_classes_), c_array(ctypes.c_int64, node_count), children_left.as_c_array(), children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(), value.as_c_array(), n_node_samples.as_c_array(), - impurity.as_c_array(), ctypes.byref(handle))) + weighted_n_node_samples.as_c_array(), impurity.as_c_array(), ctypes.byref(handle))) elif isinstance(sklearn_model, GradientBoostingR): _check_call(_LIB.TreeliteLoadSKLearnGradientBoostingRegressor( ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), c_array(ctypes.c_int64, node_count), children_left.as_c_array(), children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(), - value.as_c_array(), n_node_samples.as_c_array(), impurity.as_c_array(), - ctypes.byref(handle))) + value.as_c_array(), n_node_samples.as_c_array(), weighted_n_node_samples.as_c_array(), + impurity.as_c_array(), ctypes.byref(handle))) elif isinstance(sklearn_model, GradientBoostingC): _check_call(_LIB.TreeliteLoadSKLearnGradientBoostingClassifier( ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), ctypes.c_int(sklearn_model.n_classes_), c_array(ctypes.c_int64, node_count), children_left.as_c_array(), children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(), value.as_c_array(), n_node_samples.as_c_array(), - impurity.as_c_array(), ctypes.byref(handle))) + weighted_n_node_samples.as_c_array(), impurity.as_c_array(), ctypes.byref(handle))) else: raise TreeliteError(f'Unsupported model type {sklearn_model.__class__.__name__}: ' + 'currently random forests, extremely randomized trees, and gradient ' + diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 903e965b..9fc102ec 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -151,12 +151,12 @@ int TreeliteLoadXGBoostModelFromMemoryBuffer(const void* buf, size_t len, ModelH int TreeliteLoadSKLearnRandomForestRegressor( int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, - const double** value, const int64_t** n_node_samples, const double** impurity, - ModelHandle* out) { + const double** value, const int64_t** n_node_samples, const double** weighted_n_node_samples, + const double** impurity, ModelHandle* out) { API_BEGIN(); std::unique_ptr model = frontend::LoadSKLearnRandomForestRegressor( n_estimators, n_features, node_count, children_left, children_right, feature, threshold, - value, n_node_samples, impurity); + value, n_node_samples, weighted_n_node_samples, impurity); *out = static_cast(model.release()); API_END(); } @@ -164,12 +164,12 @@ int TreeliteLoadSKLearnRandomForestRegressor( int TreeliteLoadSKLearnIsolationForest( int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, - const double** value, const int64_t** n_node_samples, const double** impurity, - const double ratio_c, ModelHandle* out) { + const double** value, const int64_t** n_node_samples, const double** weighted_n_node_samples, + const double** impurity, const double ratio_c, ModelHandle* out) { API_BEGIN(); std::unique_ptr model = frontend::LoadSKLearnIsolationForest( n_estimators, n_features, node_count, children_left, children_right, feature, threshold, - value, n_node_samples, impurity, ratio_c); + value, n_node_samples, weighted_n_node_samples, impurity, ratio_c); *out = static_cast(model.release()); API_END(); } @@ -178,11 +178,11 @@ int TreeliteLoadSKLearnRandomForestClassifier( int n_estimators, int n_features, int n_classes, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, - const double** impurity, ModelHandle* out) { + const double** weighted_n_node_samples, const double** impurity, ModelHandle* out) { API_BEGIN(); std::unique_ptr model = frontend::LoadSKLearnRandomForestClassifier( n_estimators, n_features, n_classes, node_count, children_left, children_right, feature, - threshold, value, n_node_samples, impurity); + threshold, value, n_node_samples, weighted_n_node_samples, impurity); *out = static_cast(model.release()); API_END(); } @@ -190,12 +190,12 @@ int TreeliteLoadSKLearnRandomForestClassifier( int TreeliteLoadSKLearnGradientBoostingRegressor( int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, - const double** value, const int64_t** n_node_samples, const double** impurity, - ModelHandle* out) { + const double** value, const int64_t** n_node_samples, const double** weighted_n_node_samples, + const double** impurity, ModelHandle* out) { API_BEGIN(); std::unique_ptr model = frontend::LoadSKLearnGradientBoostingRegressor( n_estimators, n_features, node_count, children_left, children_right, feature, threshold, - value, n_node_samples, impurity); + value, n_node_samples, weighted_n_node_samples, impurity); *out = static_cast(model.release()); API_END(); } @@ -204,11 +204,11 @@ int TreeliteLoadSKLearnGradientBoostingClassifier( int n_estimators, int n_features, int n_classes, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, - const double** impurity, ModelHandle* out) { + const double** weighted_n_node_samples, const double** impurity, ModelHandle* out) { API_BEGIN(); std::unique_ptr model = frontend::LoadSKLearnGradientBoostingClassifier( n_estimators, n_features, n_classes, node_count, children_left, children_right, feature, - threshold, value, n_node_samples, impurity); + threshold, value, n_node_samples, weighted_n_node_samples, impurity); *out = static_cast(model.release()); API_END(); } diff --git a/src/frontend/sklearn.cc b/src/frontend/sklearn.cc index aca0c9f2..c8fdcfd8 100644 --- a/src/frontend/sklearn.cc +++ b/src/frontend/sklearn.cc @@ -21,7 +21,8 @@ std::unique_ptr LoadSKLearnModel( int n_trees, int n_features, int n_classes, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, - const double** impurity, MetaHandlerFunc meta_handler, LeafHandlerFunc leaf_handler) { + const double** weighted_n_node_samples, const double** impurity, MetaHandlerFunc meta_handler, + LeafHandlerFunc leaf_handler) { TREELITE_CHECK_GT(n_trees, 0); TREELITE_CHECK_GT(n_features, 0); @@ -46,6 +47,7 @@ std::unique_ptr LoadSKLearnModel( const int64_t left_child_id = children_left[tree_id][node_id]; const int64_t right_child_id = children_right[tree_id][node_id]; const int64_t sample_cnt = n_node_samples[tree_id][node_id]; + const double weighted_sample_cnt = weighted_n_node_samples[tree_id][node_id]; if (left_child_id == -1) { // leaf node leaf_handler(tree_id, node_id, new_node_id, value, n_classes, tree); } else { @@ -66,6 +68,7 @@ std::unique_ptr LoadSKLearnModel( Q.push({right_child_id, tree.RightChild(new_node_id)}); } tree.SetDataCount(new_node_id, sample_cnt); + tree.SetSumHess(new_node_id, weighted_sample_cnt); } } return model_ptr; @@ -74,7 +77,8 @@ std::unique_ptr LoadSKLearnModel( std::unique_ptr LoadSKLearnRandomForestRegressor( int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, - const double** value, const int64_t** n_node_samples, const double** impurity) { + const double** value, const int64_t** n_node_samples, const double** weighted_n_node_samples, + const double** impurity) { auto meta_handler = [](treelite::Model* model, int n_features, int n_classes) { model->num_feature = n_features; model->average_tree_output = true; @@ -92,14 +96,15 @@ std::unique_ptr LoadSKLearnRandomForestRegressor( dest_tree.SetLeaf(new_node_id, leaf_value); }; return LoadSKLearnModel(n_estimators, n_features, 1, node_count, children_left, children_right, - feature, threshold, value, n_node_samples, impurity, meta_handler, leaf_handler); + feature, threshold, value, n_node_samples, weighted_n_node_samples, impurity, meta_handler, + leaf_handler); } std::unique_ptr LoadSKLearnIsolationForest( int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, - const double** value, const int64_t** n_node_samples, const double** impurity, - const double ratio_c) { + const double** value, const int64_t** n_node_samples, const double** weighted_n_node_samples, + const double** impurity, const double ratio_c) { auto meta_handler = [ratio_c](treelite::Model* model, int n_features, int n_classes) { model->num_feature = n_features; model->average_tree_output = true; @@ -119,14 +124,15 @@ std::unique_ptr LoadSKLearnIsolationForest( dest_tree.SetLeaf(new_node_id, leaf_value); }; return LoadSKLearnModel(n_estimators, n_features, 1, node_count, children_left, children_right, - feature, threshold, value, n_node_samples, impurity, meta_handler, leaf_handler); + feature, threshold, value, n_node_samples, weighted_n_node_samples, impurity, meta_handler, + leaf_handler); } std::unique_ptr LoadSKLearnRandomForestClassifierBinary( int n_estimators, int n_features, int n_classes, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, - const double** impurity) { + const double** weighted_n_node_samples, const double** impurity) { auto meta_handler = [](treelite::Model* model, int n_features, int n_classes) { model->num_feature = n_features; model->average_tree_output = true; @@ -147,15 +153,15 @@ std::unique_ptr LoadSKLearnRandomForestClassifierBinary( dest_tree.SetLeaf(new_node_id, fraction_positive); }; return LoadSKLearnModel(n_estimators, n_features, n_classes, node_count, children_left, - children_right, feature, threshold, value, n_node_samples, impurity, meta_handler, - leaf_handler); + children_right, feature, threshold, value, n_node_samples, weighted_n_node_samples, impurity, + meta_handler, leaf_handler); } std::unique_ptr LoadSKLearnRandomForestClassifierMulticlass( int n_estimators, int n_features, int n_classes, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, - const double** impurity) { + const double** weighted_n_node_samples, const double** impurity) { auto meta_handler = [](treelite::Model* model, int n_features, int n_classes) { model->num_feature = n_features; model->average_tree_output = true; @@ -182,30 +188,32 @@ std::unique_ptr LoadSKLearnRandomForestClassifierMulticlass( dest_tree.SetLeafVector(new_node_id, prob_distribution); }; return LoadSKLearnModel(n_estimators, n_features, n_classes, node_count, children_left, - children_right, feature, threshold, value, n_node_samples, impurity, meta_handler, - leaf_handler); + children_right, feature, threshold, value, n_node_samples, weighted_n_node_samples, impurity, + meta_handler, leaf_handler); } std::unique_ptr LoadSKLearnRandomForestClassifier( int n_estimators, int n_features, int n_classes, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, - const double** impurity) { + const double** weighted_n_node_samples, const double** impurity) { TREELITE_CHECK_GE(n_classes, 2); if (n_classes == 2) { return LoadSKLearnRandomForestClassifierBinary(n_estimators, n_features, n_classes, node_count, - children_left, children_right, feature, threshold, value, n_node_samples, impurity); + children_left, children_right, feature, threshold, value, n_node_samples, + weighted_n_node_samples, impurity); } else { return LoadSKLearnRandomForestClassifierMulticlass(n_estimators, n_features, n_classes, node_count, children_left, children_right, feature, threshold, value, n_node_samples, - impurity); + weighted_n_node_samples, impurity); } } std::unique_ptr LoadSKLearnGradientBoostingRegressor( int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, - const double** value, const int64_t** n_node_samples, const double** impurity) { + const double** value, const int64_t** n_node_samples, const double** weighted_n_node_samples, + const double** impurity) { auto meta_handler = [](treelite::Model* model, int n_features, int n_classes) { model->num_feature = n_features; model->average_tree_output = false; @@ -223,15 +231,15 @@ std::unique_ptr LoadSKLearnGradientBoostingRegressor( dest_tree.SetLeaf(new_node_id, leaf_value); }; return LoadSKLearnModel(n_estimators, n_features, 1, node_count, children_left, - children_right, feature, threshold, value, n_node_samples, impurity, meta_handler, - leaf_handler); + children_right, feature, threshold, value, n_node_samples, weighted_n_node_samples, impurity, + meta_handler, leaf_handler); } std::unique_ptr LoadSKLearnGradientBoostingClassifierBinary( int n_estimators, int n_features, int n_classes, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, - const double** impurity) { + const double** weighted_n_node_samples, const double** impurity) { auto meta_handler = [](treelite::Model* model, int n_features, int n_classes) { model->num_feature = n_features; model->average_tree_output = false; @@ -249,15 +257,15 @@ std::unique_ptr LoadSKLearnGradientBoostingClassifierBinary( dest_tree.SetLeaf(new_node_id, leaf_value); }; return LoadSKLearnModel(n_estimators, n_features, n_classes, node_count, children_left, - children_right, feature, threshold, value, n_node_samples, impurity, meta_handler, - leaf_handler); + children_right, feature, threshold, value, n_node_samples, weighted_n_node_samples, impurity, + meta_handler, leaf_handler); } std::unique_ptr LoadSKLearnGradientBoostingClassifierMulticlass( int n_estimators, int n_features, int n_classes, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, - const double** impurity) { + const double** weighted_n_node_samples, const double** impurity) { auto meta_handler = [](treelite::Model* model, int n_features, int n_classes) { model->num_feature = n_features; model->average_tree_output = false; @@ -275,24 +283,24 @@ std::unique_ptr LoadSKLearnGradientBoostingClassifierMulticlass dest_tree.SetLeaf(new_node_id, leaf_value); }; return LoadSKLearnModel(n_estimators * n_classes, n_features, n_classes, node_count, - children_left, children_right, feature, threshold, value, n_node_samples, impurity, - meta_handler, leaf_handler); + children_left, children_right, feature, threshold, value, n_node_samples, + weighted_n_node_samples, impurity, meta_handler, leaf_handler); } std::unique_ptr LoadSKLearnGradientBoostingClassifier( int n_estimators, int n_features, int n_classes, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, - const double** impurity) { + const double** weighted_n_node_samples, const double** impurity) { TREELITE_CHECK_GE(n_classes, 2); if (n_classes == 2) { return LoadSKLearnGradientBoostingClassifierBinary(n_estimators, n_features, n_classes, node_count, children_left, children_right, feature, threshold, value, n_node_samples, - impurity); + weighted_n_node_samples, impurity); } else { return LoadSKLearnGradientBoostingClassifierMulticlass(n_estimators, n_features, n_classes, node_count, children_left, children_right, feature, threshold, value, n_node_samples, - impurity); + weighted_n_node_samples, impurity); } }