From 88819281b4a93a05a53e1d45f0e588d51c9ae3c2 Mon Sep 17 00:00:00 2001
From: Andy Adinets
Date: Fri, 26 Aug 2022 18:19:54 +0200
Subject: [PATCH] Import treelite models into FIL in a different precision (#4839)

Import treelite models into FIL in a different precision.

- e.g. load float64 treelite models as a float32 FIL model, or vice versa

Authors:
  - Andy Adinets (https://github.com/canonizer)
  - William Hicks (https://github.com/wphicks)

Approvers:
  - Philip Hyunsu Cho (https://github.com/hcho3)
  - William Hicks (https://github.com/wphicks)

URL: https://github.com/rapidsai/cuml/pull/4839
---
 cpp/include/cuml/fil/fil.h     | 14 +++++++++
 cpp/src/fil/treelite_import.cu | 36 ++++++++++++++++++----
 cpp/test/sg/fil_test.cu        |  1 +
 python/cuml/fil/fil.pyx        | 54 ++++++++++++++++++++++++++++++++-
 python/cuml/tests/test_fil.py  | 55 ++++++++++++++++++++++++----------
 5 files changed, 138 insertions(+), 22 deletions(-)

diff --git a/cpp/include/cuml/fil/fil.h b/cpp/include/cuml/fil/fil.h
index 2d5d786520..a5b0b6b2aa 100644
--- a/cpp/include/cuml/fil/fil.h
+++ b/cpp/include/cuml/fil/fil.h
@@ -69,6 +69,18 @@ enum storage_type_t {
 };
 static const char* storage_type_repr[] = {"AUTO", "DENSE", "SPARSE", "SPARSE8"};
 
+/** precision_t defines the precision of the FIL model imported from a treelite model */
+enum precision_t {
+  /** use the native precision of the treelite model, i.e. float64 if it has weights or
+      thresholds of type float64, otherwise float32 */
+  PRECISION_NATIVE,
+  /** always create a float32 FIL model; this may lead to loss of precision if the
+      treelite model contains float64 parameters */
+  PRECISION_FLOAT32,
+  /** always create a float64 FIL model */
+  PRECISION_FLOAT64
+};
+
 template <typename real_t>
 struct forest;
 
@@ -113,6 +125,8 @@ struct treelite_params_t {
   // if non-nullptr, *pforest_shape_str will be set to caller-owned string that
   // contains forest shape
   char** pforest_shape_str;
+  // precision in which to load the treelite model
+  precision_t precision;
 };
 
 /** from_treelite uses a treelite model to initialize the forest
diff --git a/cpp/src/fil/treelite_import.cu b/cpp/src/fil/treelite_import.cu
index 2b9e320c95..aeb113dd3b 100644
--- a/cpp/src/fil/treelite_import.cu
+++ b/cpp/src/fil/treelite_import.cu
@@ -673,13 +673,39 @@ void from_treelite(const raft::handle_t& handle,
                    forest_variant* pforest_variant,
                    const tl::ModelImpl<threshold_t, leaf_t>& model,
                    const treelite_params_t* tl_params)
 {
-  // floating-point type used for model representation
-  using real_t = decltype(threshold_t(0) + leaf_t(0));
+  precision_t precision = tl_params->precision;
+  // choose the precision based on model if required
+  if (precision == PRECISION_NATIVE) {
+    precision = std::is_same_v<decltype(threshold_t(0) + leaf_t(0)), float> ? PRECISION_FLOAT32
+                                                                            : PRECISION_FLOAT64;
+  }
 
-  // get the pointer to the right forest variant
-  *pforest_variant = (forest_t<real_t>)nullptr;
-  forest_t<real_t>* pforest = &std::get<forest_t<real_t>>(*pforest_variant);
+  switch (precision) {
+    case PRECISION_FLOAT32: {
+      *pforest_variant         = (forest_t<float>)nullptr;
+      forest_t<float>* pforest = &std::get<forest_t<float>>(*pforest_variant);
+      from_treelite(handle, pforest, model, tl_params);
+      break;
+    }
+    case PRECISION_FLOAT64: {
+      *pforest_variant          = (forest_t<double>)nullptr;
+      forest_t<double>* pforest = &std::get<forest_t<double>>(*pforest_variant);
+      from_treelite(handle, pforest, model, tl_params);
+      break;
+    }
+    default:
+      ASSERT(false,
+             "bad value of tl_params->precision, must be one of "
+             "PRECISION_{NATIVE,FLOAT32,FLOAT64}");
+  }
+}
+
+template <typename threshold_t, typename leaf_t, typename real_t>
+void from_treelite(const raft::handle_t& handle,
+                   forest_t<real_t>* pforest,
+                   const tl::ModelImpl<threshold_t, leaf_t>& model,
+                   const treelite_params_t* tl_params)
+{
   // Invariants on threshold and leaf types
   static_assert(type_supported<threshold_t>(),
                 "Model must contain float32 or float64 thresholds for splits");
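The two hunks above turn the variant-based `from_treelite` overload into a thin dispatcher: a requested `PRECISION_NATIVE` is first resolved to `PRECISION_FLOAT32` or `PRECISION_FLOAT64` from the model's threshold and leaf types, and the resolved case then populates the matching `forest_t<float>` or `forest_t<double>` alternative of the output variant. The resolution rule can be summarized by the following sketch (illustrative only, not part of the patch; the function name is hypothetical):

```python
# Hypothetical sketch of the PRECISION_NATIVE resolution rule above: a model
# whose combined threshold/leaf type is float32 stays float32; any float64
# component promotes the whole model to float64, mirroring
# decltype(threshold_t(0) + leaf_t(0)) in the C++ code.
def resolve_precision(requested: str, threshold_dtype: str, leaf_dtype: str) -> str:
    if requested != 'native':
        return requested  # 'float32' or 'float64' are used as-is
    native_is_float32 = threshold_dtype == 'float32' and leaf_dtype == 'float32'
    return 'float32' if native_is_float32 else 'float64'

assert resolve_precision('native', 'float32', 'float32') == 'float32'
assert resolve_precision('native', 'float64', 'float32') == 'float64'
assert resolve_precision('float32', 'float64', 'float64') == 'float32'
```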
diff --git a/cpp/test/sg/fil_test.cu b/cpp/test/sg/fil_test.cu
index 82955ec8ae..7ab0896055 100644
--- a/cpp/test/sg/fil_test.cu
+++ b/cpp/test/sg/fil_test.cu
@@ -891,6 +891,7 @@ class TreeliteFilTest : public BaseFilTest<real_t> {
     params.threads_per_tree  = this->ps.threads_per_tree;
     params.n_items           = this->ps.n_items;
     params.pforest_shape_str = this->ps.print_forest_shape ? &forest_shape_str : nullptr;
+    params.precision         = fil::PRECISION_NATIVE;
     fil::forest_variant forest_variant;
     fil::from_treelite(this->handle, &forest_variant, (ModelHandle)model.get(), &params);
     *pforest = std::get<fil::forest_t<real_t>>(forest_variant);
diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx
index 8a71368fb9..b4855c3b71 100644
--- a/python/cuml/fil/fil.pyx
+++ b/python/cuml/fil/fil.pyx
@@ -197,6 +197,11 @@ cdef extern from "cuml/fil/fil.h" namespace "ML::fil":
         SPARSE,
         SPARSE8
 
+    cdef enum precision_t:
+        PRECISION_NATIVE,
+        PRECISION_FLOAT32,
+        PRECISION_FLOAT64
+
     cdef cppclass forest[real_t]:
         pass
 
@@ -227,6 +232,7 @@ cdef extern from "cuml/fil/fil.h" namespace "ML::fil":
         int n_items
         # this affects inference performance and will become configurable soon
         char** pforest_shape_str
+        precision_t precision
 
     cdef void free[real_t](handle_t& handle, forest[real_t]*)
 
@@ -305,6 +311,18 @@ cdef class ForestInference_impl():
             logger.info('storage_type=="sparse8" is an experimental feature')
         return storage_type_dict[storage_type_str]
 
+    def get_precision(self, precision):
+        precision_dict = {'native': precision_t.PRECISION_NATIVE,
+                          'float32': precision_t.PRECISION_FLOAT32,
+                          'float64': precision_t.PRECISION_FLOAT64}
+        if precision not in precision_dict:
+            raise ValueError(
+                "The value entered for precision is not "
+                "supported. Please refer to the documentation at "
+                "(https://docs.rapids.ai/api/cuml/nightly/api.html#"
+                "forest-inferencing) to see the accepted values.")
+        return precision_dict[precision]
+
     def predict(self, X,
                 output_dtype=None,
                 predict_proba=False,
@@ -432,6 +450,7 @@ cdef class ForestInference_impl():
             treelite_params.pforest_shape_str = &self.shape_str
         else:
             treelite_params.pforest_shape_str = NULL
+        treelite_params.precision = self.get_precision(kwargs['precision'])
 
         cdef handle_t* handle_ =\
             self.handle.getHandle()
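On the Python side, the `precision` string is validated by `get_precision` above and forwarded through `treelite_params` to the C++ importer. A minimal usage sketch (not part of the patch; assumes scikit-learn is installed, with toy data and a placeholder model):

```python
# Illustrative sketch of the new precision kwarg reaching the binding above.
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from cuml import ForestInference

X = np.random.random((100, 8)).astype(np.float32)
y = (X[:, 0] > 0.5).astype(np.int32)
skl_model = RandomForestClassifier(n_estimators=10).fit(X, y)

# Force a float32 FIL model regardless of the treelite model's own precision;
# 'native' (the default) would keep the model's precision instead.
fm = ForestInference.load_from_sklearn(skl_model,
                                       output_class=True,
                                       threshold=0.50,
                                       precision='float32')
preds = fm.predict(X)
```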
@@ -591,6 +610,15 @@ class ForestInference(Base,
             if True or equivalent, creates a ForestInference.shape_str
             (writes a human-readable forest shape description as a
             multiline ascii string)
+        precision : string (default='native')
+            precision of weights and thresholds of the FIL model loaded from
+            the treelite model.
+
+            - ``'native'``: load in float64 if the treelite model contains float64
+              weights or thresholds, otherwise load in float32
+            - ``'float32'``: always load in float32, may lead to loss of precision
+              if the treelite model contains float64 weights or thresholds
+            - ``'float64'``: always load in float64
         """)
         return func
 
@@ -680,7 +708,7 @@ class ForestInference(Base,
                                  threads_per_tree=1,
                                  n_items=0,
                                  compute_shape_str=False,
-                                 ):
+                                 precision='native'):
         """Creates a FIL model using the treelite model
         passed to the function.
 
@@ -718,6 +746,7 @@ class ForestInference(Base,
                           threads_per_tree=1,
                           n_items=0,
                           compute_shape_str=False,
+                          precision='native',
                           handle=None):
         """
         Creates a FIL model using the scikit-learn model passed to the
@@ -769,6 +798,16 @@ class ForestInference(Base,
             if True or equivalent, creates a ForestInference.shape_str
             (writes a human-readable forest shape description as a
             multiline ascii string)
+        precision : string (default='native')
+            precision of weights and thresholds of the FIL model loaded from
+            the treelite model.
+
+            - ``'native'``: load in float64 if the treelite model contains
+              float64 weights or thresholds, otherwise load in float32
+            - ``'float32'``: always load in float32, may lead to loss of
+              precision if the treelite model contains float64 weights or
+              thresholds
+            - ``'float64'``: always load in float64
 
         Returns
         ----------
@@ -797,6 +836,7 @@ class ForestInference(Base,
              threads_per_tree=1,
              n_items=0,
              compute_shape_str=False,
+             precision='native',
              model_type="xgboost",
              handle=None):
         """
@@ -851,6 +891,17 @@ class ForestInference(Base,
             if True or equivalent, creates a ForestInference.shape_str
             (writes a human-readable forest shape description as a
             multiline ascii string)
+        precision : string (default='native')
+            precision of weights and thresholds of the FIL model loaded from
+            the treelite model.
+
+            - ``'native'``: load in float64 if the treelite model contains
+              float64 weights or thresholds, otherwise load in float32
+            - ``'float32'``: always load in float32, may lead to loss of
+              precision if the treelite model contains float64 weights or
+              thresholds
+            - ``'float64'``: always load in float64
+
         model_type : string (default="xgboost")
             Format of the saved treelite model to be loaded. It can be
             'xgboost', 'xgboost_json', 'lightgbm'.
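The `load` classmethod documented above combines `model_type` with the new `precision` option when importing a saved model file. A brief sketch (not part of the patch; `'model.json'` is a placeholder path):

```python
# Illustrative sketch: load a saved XGBoost JSON model as a float64 FIL model,
# even if the saved model stores float32 weights and thresholds.
from cuml import ForestInference

fm64 = ForestInference.load('model.json',
                            model_type='xgboost_json',
                            output_class=True,
                            threshold=0.50,
                            precision='float64')
```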
@@ -880,6 +931,7 @@ class ForestInference(Base,
                                     threads_per_tree=1,
                                     n_items=0,
                                     compute_shape_str=False,
+                                    precision='native'
                                     ):
         """
         Returns a FIL instance by converting a treelite model to
diff --git a/python/cuml/tests/test_fil.py b/python/cuml/tests/test_fil.py
index 7765ea883f..0d5ae50bb2 100644
--- a/python/cuml/tests/test_fil.py
+++ b/python/cuml/tests/test_fil.py
@@ -36,7 +36,7 @@
 if has_xgboost():
     import xgboost as xgb
 
-pytestmark = pytest.mark.skip
+# pytestmark = pytest.mark.skip
 
 
 def simulate_data(m, n, k=2, n_informative='auto', random_state=None,
@@ -212,21 +212,26 @@ def test_fil_regression(n_rows, n_columns, num_rounds, tmp_path, max_depth):
                          [(2, False), (2, True), (10, False), (10, True),
                           (20, True)])
 # When n_classes=25, fit a single estimator only to reduce test time
-@pytest.mark.parametrize('n_classes,model_class,n_estimators',
-                         [(2, GradientBoostingClassifier, 1),
-                          (2, GradientBoostingClassifier, 10),
-                          (2, RandomForestClassifier, 1),
-                          (5, RandomForestClassifier, 1),
-                          (2, RandomForestClassifier, 10),
-                          (5, RandomForestClassifier, 10),
-                          (2, ExtraTreesClassifier, 1),
-                          (2, ExtraTreesClassifier, 10),
-                          (5, GradientBoostingClassifier, 1),
-                          (5, GradientBoostingClassifier, 10),
-                          (25, GradientBoostingClassifier, 1),
-                          (25, RandomForestClassifier, 1)])
+@pytest.mark.parametrize('n_classes,model_class,n_estimators,precision',
+                         [(2, GradientBoostingClassifier, 1, 'native'),
+                          (2, GradientBoostingClassifier, 10, 'native'),
+                          (2, RandomForestClassifier, 1, 'native'),
+                          (5, RandomForestClassifier, 1, 'native'),
+                          (2, RandomForestClassifier, 10, 'native'),
+                          (5, RandomForestClassifier, 10, 'native'),
+                          (2, ExtraTreesClassifier, 1, 'native'),
+                          (2, ExtraTreesClassifier, 10, 'native'),
+                          (5, GradientBoostingClassifier, 1, 'native'),
+                          (5, GradientBoostingClassifier, 10, 'native'),
+                          (25, GradientBoostingClassifier, 1, 'native'),
+                          (25, RandomForestClassifier, 1, 'native'),
+                          (2, RandomForestClassifier, 10, 'float32'),
+                          (2, RandomForestClassifier, 10, 'float64'),
+                          (5, RandomForestClassifier, 10, 'float32'),
+                          (5, RandomForestClassifier, 10, 'float64')])
 def test_fil_skl_classification(n_rows, n_columns, n_estimators, max_depth,
-                                n_classes, storage_type, model_class):
+                                n_classes, storage_type, precision,
+                                model_class):
     # settings
     classification = True  # change this to false to use regression
     random_state = np.random.RandomState(43210)
@@ -266,7 +271,8 @@ def test_fil_skl_classification(n_rows, n_columns, n_estimators, max_depth,
                                          algo=algo,
                                          output_class=True,
                                          threshold=0.50,
-                                         storage_type=storage_type)
+                                         storage_type=storage_type,
+                                         precision=precision)
     fil_preds = np.asarray(fm.predict(X_validation))
     fil_preds = np.reshape(fil_preds, np.shape(skl_preds_int))
     fil_acc = accuracy_score(y_validation, fil_preds)
@@ -389,6 +395,23 @@ def test_output_algos(algo, small_classifier_and_preds):
     assert np.allclose(fil_preds, xgb_preds_int, 1e-3)
 
 
+@pytest.mark.skipif(has_xgboost() is False, reason="need to install xgboost")
+@pytest.mark.parametrize('precision', ['native', 'float32', 'float64'])
+def test_precision_xgboost(precision, small_classifier_and_preds):
+    model_path, model_type, X, xgb_preds = small_classifier_and_preds
+    fm = ForestInference.load(model_path,
+                              model_type=model_type,
+                              output_class=True,
+                              threshold=0.50,
+                              precision=precision)
+
+    xgb_preds_int = np.around(xgb_preds)
+    fil_preds = np.asarray(fm.predict(X))
+    fil_preds = np.reshape(fil_preds, np.shape(xgb_preds_int))
+
+    assert np.allclose(fil_preds, xgb_preds_int, 1e-3)
+
+
 @pytest.mark.skipif(has_xgboost() is False, reason="need to install xgboost")
 @pytest.mark.parametrize('storage_type', [False, True, 'auto', 'dense',
                                           'sparse', 'sparse8'])