Import treelite models into FIL in a different precision #4839

Merged: 9 commits, Aug 26, 2022
cpp/include/cuml/fil/fil.h: 14 additions, 0 deletions

@@ -69,6 +69,18 @@ enum storage_type_t {
};
static const char* storage_type_repr[] = {"AUTO", "DENSE", "SPARSE", "SPARSE8"};

/** precision_t defines the precision of the FIL model imported from a treelite model */
enum precision_t {
/** use the native precision of the treelite model, i.e. float64 if it has weights or
thresholds of type float64, otherwise float32 */
PRECISION_NATIVE,
/** always create a float32 FIL model; this may lead to loss of precision if the
treelite model contains float64 parameters */
PRECISION_FLOAT32,
/** always create a float64 FIL model */
PRECISION_FLOAT64
};

template <typename real_t>
struct forest;

@@ -113,6 +125,8 @@ struct treelite_params_t {
// if non-nullptr, *pforest_shape_str will be set to caller-owned string that
// contains forest shape
char** pforest_shape_str;
// precision in which to load the treelite model
precision_t precision;
};

/** from_treelite uses a treelite model to initialize the forest
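For orientation, here is a minimal caller-side sketch of the new field. It is not part of this diff: the raft handle and treelite ModelHandle are assumed to already exist, and the remaining treelite_params_t members are left at their defaults for brevity.

    #include <variant>
    #include <cuml/fil/fil.h>

    // Hypothetical sketch (not from this PR): import a treelite model into
    // FIL as an explicitly float64 forest via the new `precision` field.
    ML::fil::forest_t<double> load_forest_f64(const raft::handle_t& handle,
                                              ModelHandle model)
    {
      ML::fil::treelite_params_t params{};  // real code also sets algo, output_class, ...
      params.precision = ML::fil::PRECISION_FLOAT64;

      ML::fil::forest_variant forest;
      ML::fil::from_treelite(handle, &forest, model, &params);

      // With PRECISION_FLOAT64 the variant is guaranteed to hold a forest_t<double>.
      return std::get<ML::fil::forest_t<double>>(forest);
    }

Under PRECISION_NATIVE the held alternative instead depends on the treelite model's threshold and leaf types, as the dispatcher in the next file shows.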
cpp/src/fil/treelite_import.cu: 30 additions, 5 deletions

@@ -673,13 +673,38 @@ void from_treelite(const raft::handle_t& handle,
const tl::ModelImpl<threshold_t, leaf_t>& model,
const treelite_params_t* tl_params)
{
-  // floating-point type used for model representation
-  using real_t = decltype(threshold_t(0) + leaf_t(0));
+  precision_t precision = tl_params->precision;
+  // choose the precision based on model if required
+  if (precision == PRECISION_NATIVE) {
+    precision = std::is_same_v<decltype(threshold_t(0) + leaf_t(0)), float> ? PRECISION_FLOAT32
+                                                                            : PRECISION_FLOAT64;
+  }

-  // get the pointer to the right forest variant
-  *pforest_variant = (forest_t<real_t>)nullptr;
-  forest_t<real_t>* pforest = &std::get<forest_t<real_t>>(*pforest_variant);
+  switch (precision) {
+    case PRECISION_FLOAT32: {
+      *pforest_variant = (forest_t<float>)nullptr;
+      forest_t<float>* pforest = &std::get<forest_t<float>>(*pforest_variant);
+      from_treelite(handle, pforest, model, tl_params);
+      break;
+    }
+    case PRECISION_FLOAT64: {
+      *pforest_variant = (forest_t<double>)nullptr;
+      forest_t<double>* pforest = &std::get<forest_t<double>>(*pforest_variant);
+      from_treelite(handle, pforest, model, tl_params);
+      break;
+    }
+    default:
+      ASSERT(false,
+             "bad value of tl_params->precision, must be one of "
+             "PRECISION_{NATIVE,FLOAT32,FLOAT64}");
+  }
+}
+
+template <typename threshold_t, typename leaf_t, typename real_t>
+void from_treelite(const raft::handle_t& handle,
+                   forest_t<real_t>* pforest,
+                   const tl::ModelImpl<threshold_t, leaf_t>& model,
+                   const treelite_params_t* tl_params)
+{
// Invariants on threshold and leaf types
static_assert(type_supported<threshold_t>(),
"Model must contain float32 or float64 thresholds for splits");
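A note on the PRECISION_NATIVE branch above: decltype(threshold_t(0) + leaf_t(0)) names float only when both treelite types are float, because the usual arithmetic conversions promote any expression with a double operand to double. A standalone sketch (not from this PR) that verifies the rule:

    #include <type_traits>

    // float + float stays float; any double operand promotes the sum to double.
    static_assert(std::is_same_v<decltype(float(0) + float(0)), float>);
    static_assert(std::is_same_v<decltype(float(0) + double(0)), double>);
    static_assert(std::is_same_v<decltype(double(0) + double(0)), double>);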
cpp/test/sg/fil_test.cu: 1 addition, 0 deletions

@@ -891,6 +891,7 @@ class TreeliteFilTest : public BaseFilTest<real_t> {
params.threads_per_tree = this->ps.threads_per_tree;
params.n_items = this->ps.n_items;
params.pforest_shape_str = this->ps.print_forest_shape ? &forest_shape_str : nullptr;
params.precision = fil::PRECISION_NATIVE;
fil::forest_variant forest_variant;
fil::from_treelite(this->handle, &forest_variant, (ModelHandle)model.get(), &params);
*pforest = std::get<fil::forest_t<real_t>>(forest_variant);
python/cuml/fil/fil.pyx: 48 additions, 1 deletion

@@ -198,6 +198,11 @@ cdef extern from "cuml/fil/fil.h" namespace "ML::fil":
SPARSE,
SPARSE8

cdef enum precision_t:
PRECISION_NATIVE,
PRECISION_FLOAT32,
PRECISION_FLOAT64

cdef cppclass forest[real_t]:
pass

@@ -228,6 +233,7 @@ cdef extern from "cuml/fil/fil.h" namespace "ML::fil":
int n_items
# this affects inference performance and will become configurable soon
char** pforest_shape_str
precision_t precision

cdef void free[real_t](handle_t& handle,
forest[real_t]*)
@@ -306,6 +312,18 @@ cdef class ForestInference_impl():
logger.info('storage_type=="sparse8" is an experimental feature')
return storage_type_dict[storage_type_str]

def get_precision(self, precision):
precision_dict = {'native': precision_t.PRECISION_NATIVE,
'float32': precision_t.PRECISION_FLOAT32,
'float64': precision_t.PRECISION_FLOAT64}
if precision not in precision_dict.keys():
Contributor (review comment):

Suggested change:
-    if precision not in precision_dict.keys():
+    if precision not in precision_dict:

to be more Pythonic

Contributor Author: Done.

raise ValueError(
    "The value entered for precision is not "
    "supported. Please refer to the documentation at "
    "(https://docs.rapids.ai/api/cuml/nightly/api.html#"
    "forest-inferencing) to see the accepted values.")
return precision_dict[precision]
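Illustrative calls (hypothetical, assuming an initialized ForestInference_impl instance named impl; not part of the diff):

    impl.get_precision('float32')  # -> precision_t.PRECISION_FLOAT32
    impl.get_precision('native')   # -> precision_t.PRECISION_NATIVE
    impl.get_precision('float16')  # raises ValueError pointing to the docs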

def predict(self, X,
output_dtype=None,
predict_proba=False,
@@ -433,6 +451,7 @@ cdef class ForestInference_impl():
treelite_params.pforest_shape_str = &self.shape_str
else:
treelite_params.pforest_shape_str = NULL
treelite_params.precision = self.get_precision(kwargs['precision'])

cdef handle_t* handle_ =\
<handle_t*><size_t>self.handle.getHandle()
@@ -592,6 +611,14 @@ class ForestInference(Base,
if True or equivalent, creates a ForestInference.shape_str
(writes a human-readable forest shape description as a
multiline ascii string)
precision : string (default='native')
precision of weights and thresholds of the FIL model loaded from
the treelite model.
- ``'native'``: load in float64 if the treelite model contains float64
weights or thresholds, otherwise load in float32
- ``'float32'``: always load in float32, may lead to loss of precision if
the treelite model contains float64 weights or thresholds
- ``'float64'``: always load in float64
""")
return func
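As a usage sketch of the new keyword (the file name and feature array are illustrative; the parameter names come from the docstring above and the tests below):

    from cuml import ForestInference

    # Hypothetical model file; 'float64' preserves float64 thresholds exactly.
    fm = ForestInference.load('model.json',
                              model_type='xgboost_json',
                              output_class=True,
                              threshold=0.50,
                              precision='float64')
    preds = fm.predict(X)  # X: an existing feature array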

@@ -681,7 +708,7 @@ class ForestInference(Base,
threads_per_tree=1,
n_items=0,
compute_shape_str=False,
-):
+precision='native'):
"""Creates a FIL model using the treelite model
passed to the function.

@@ -719,6 +746,7 @@
threads_per_tree=1,
n_items=0,
compute_shape_str=False,
precision='native',
handle=None):
"""
Creates a FIL model using the scikit-learn model passed to the
@@ -770,6 +798,14 @@
if True or equivalent, creates a ForestInference.shape_str
(writes a human-readable forest shape description as a
multiline ascii string)
precision : string (default='native')
precision of weights and thresholds of the FIL model loaded from
the treelite model.
- ``'native'``: load in float64 if the treelite model contains float64
weights or thresholds, otherwise load in float32
- ``'float32'``: always load in float32, may lead to loss of precision if
the treelite model contains float64 weights or thresholds
- ``'float64'``: always load in float64

Returns
----------
@@ -798,6 +834,7 @@
threads_per_tree=1,
n_items=0,
compute_shape_str=False,
precision='native',
model_type="xgboost",
handle=None):
"""
@@ -852,6 +889,15 @@
if True or equivalent, creates a ForestInference.shape_str
(writes a human-readable forest shape description as a
multiline ascii string)
precision : string (default='native')
precision of weights and thresholds of the FIL model loaded from
the treelite model.
- ``'native'``: load in float64 if the treelite model contains float64
weights or thresholds, otherwise load in float32
- ``'float32'``: always load in float32, may lead to loss of precision if
the treelite model contains float64 weights or thresholds
- ``'float64'``: always load in float64

model_type : string (default="xgboost")
Format of the saved treelite model to be loaded.
It can be 'xgboost', 'xgboost_json', 'lightgbm'.
@@ -881,6 +927,7 @@
threads_per_tree=1,
n_items=0,
compute_shape_str=False,
precision='native'
):
"""
Returns a FIL instance by converting a treelite model to
python/cuml/tests/test_fil.py: 38 additions, 16 deletions

@@ -36,7 +36,7 @@
if has_xgboost():
import xgboost as xgb

-pytestmark = pytest.mark.skip
+# pytestmark = pytest.mark.skip


def simulate_data(m, n, k=2, n_informative='auto', random_state=None,
@@ -212,21 +212,25 @@ def test_fil_regression(n_rows, n_columns, num_rounds, tmp_path, max_depth):
[(2, False), (2, True), (10, False), (10, True),
(20, True)])
# When n_classes=25, fit a single estimator only to reduce test time
-@pytest.mark.parametrize('n_classes,model_class,n_estimators',
-                         [(2, GradientBoostingClassifier, 1),
-                          (2, GradientBoostingClassifier, 10),
-                          (2, RandomForestClassifier, 1),
-                          (5, RandomForestClassifier, 1),
-                          (2, RandomForestClassifier, 10),
-                          (5, RandomForestClassifier, 10),
-                          (2, ExtraTreesClassifier, 1),
-                          (2, ExtraTreesClassifier, 10),
-                          (5, GradientBoostingClassifier, 1),
-                          (5, GradientBoostingClassifier, 10),
-                          (25, GradientBoostingClassifier, 1),
-                          (25, RandomForestClassifier, 1)])
+@pytest.mark.parametrize('n_classes,model_class,n_estimators,precision',
+                         [(2, GradientBoostingClassifier, 1, 'native'),
+                          (2, GradientBoostingClassifier, 10, 'native'),
+                          (2, RandomForestClassifier, 1, 'native'),
+                          (5, RandomForestClassifier, 1, 'native'),
+                          (2, RandomForestClassifier, 10, 'native'),
+                          (5, RandomForestClassifier, 10, 'native'),
+                          (2, ExtraTreesClassifier, 1, 'native'),
+                          (2, ExtraTreesClassifier, 10, 'native'),
+                          (5, GradientBoostingClassifier, 1, 'native'),
+                          (5, GradientBoostingClassifier, 10, 'native'),
+                          (25, GradientBoostingClassifier, 1, 'native'),
+                          (25, RandomForestClassifier, 1, 'native'),
+                          (2, RandomForestClassifier, 10, 'float32'),
+                          (2, RandomForestClassifier, 10, 'float64'),
+                          (5, RandomForestClassifier, 10, 'float32'),
+                          (5, RandomForestClassifier, 10, 'float64')])
 def test_fil_skl_classification(n_rows, n_columns, n_estimators, max_depth,
-                                n_classes, storage_type, model_class):
+                                n_classes, storage_type, precision, model_class):
# settings
classification = True # change this to false to use regression
random_state = np.random.RandomState(43210)
@@ -266,7 +270,8 @@ def test_fil_skl_classification(n_rows, n_columns, n_estimators, max_depth,
algo=algo,
output_class=True,
threshold=0.50,
-storage_type=storage_type)
+storage_type=storage_type,
+precision=precision)
fil_preds = np.asarray(fm.predict(X_validation))
fil_preds = np.reshape(fil_preds, np.shape(skl_preds_int))
fil_acc = accuracy_score(y_validation, fil_preds)
@@ -389,6 +394,23 @@ def test_output_algos(algo, small_classifier_and_preds):
assert np.allclose(fil_preds, xgb_preds_int, 1e-3)


@pytest.mark.skipif(has_xgboost() is False, reason="need to install xgboost")
@pytest.mark.parametrize('precision', ['native', 'float32', 'float64'])
def test_precision_xgboost(precision, small_classifier_and_preds):
model_path, model_type, X, xgb_preds = small_classifier_and_preds
fm = ForestInference.load(model_path,
model_type=model_type,
output_class=True,
threshold=0.50,
precision=precision)

xgb_preds_int = np.around(xgb_preds)
fil_preds = np.asarray(fm.predict(X))
fil_preds = np.reshape(fil_preds, np.shape(xgb_preds_int))

assert np.allclose(fil_preds, xgb_preds_int, 1e-3)
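A natural companion check, sketched here but not part of this PR: load the same model at both fixed precisions and confirm the predictions agree within the tolerance used above.

    def test_precision_agreement(small_classifier_and_preds):
        # Hypothetical test: float32 and float64 imports of the same model
        # should yield near-identical predictions on the same inputs.
        model_path, model_type, X, _ = small_classifier_and_preds
        fm32 = ForestInference.load(model_path, model_type=model_type,
                                    output_class=True, precision='float32')
        fm64 = ForestInference.load(model_path, model_type=model_type,
                                    output_class=True, precision='float64')
        np.testing.assert_allclose(np.asarray(fm32.predict(X)),
                                   np.asarray(fm64.predict(X)), atol=1e-3)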


@pytest.mark.skipif(has_xgboost() is False, reason="need to install xgboost")
@pytest.mark.parametrize('storage_type',
[False, True, 'auto', 'dense', 'sparse', 'sparse8'])