Skip to content

Commit

Permalink
Import treelite models into FIL in a different precision (rapidsai#4839)
Browse files Browse the repository at this point in the history
Import treelite models into FIL in a different precision.

- e.g. load float64 treelite models as a float32 FIL model, or vice versa

Authors:
  - Andy Adinets (https://github.com/canonizer)
  - William Hicks (https://github.com/wphicks)

Approvers:
  - Philip Hyunsu Cho (https://github.com/hcho3)
  - William Hicks (https://github.com/wphicks)

URL: rapidsai#4839
  • Loading branch information
canonizer authored and divyegala committed Sep 2, 2022
1 parent ac68939 commit dadedda
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 22 deletions.
14 changes: 14 additions & 0 deletions cpp/include/cuml/fil/fil.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,18 @@ enum storage_type_t {
};
static const char* storage_type_repr[] = {"AUTO", "DENSE", "SPARSE", "SPARSE8"};

/** precision_t defines the precision of the FIL model imported from a treelite model */
enum precision_t {
/** use the native precision of the treelite model, i.e. float64 if it has weights or
thresholds of type float64, otherwise float32 */
PRECISION_NATIVE,
/** always create a float32 FIL model; this may lead to loss of precision if the
treelite model contains float64 parameters */
PRECISION_FLOAT32,
/** always create a float64 FIL model */
PRECISION_FLOAT64
};

template <typename real_t>
struct forest;

Expand Down Expand Up @@ -113,6 +125,8 @@ struct treelite_params_t {
// if non-nullptr, *pforest_shape_str will be set to caller-owned string that
// contains forest shape
char** pforest_shape_str;
// precision in which to load the treelite model
precision_t precision;
};

/** from_treelite uses a treelite model to initialize the forest
Expand Down
36 changes: 31 additions & 5 deletions cpp/src/fil/treelite_import.cu
Original file line number Diff line number Diff line change
Expand Up @@ -673,13 +673,39 @@ void from_treelite(const raft::handle_t& handle,
const tl::ModelImpl<threshold_t, leaf_t>& model,
const treelite_params_t* tl_params)
{
// floating-point type used for model representation
using real_t = decltype(threshold_t(0) + leaf_t(0));
precision_t precision = tl_params->precision;
// choose the precision based on model if required
if (precision == PRECISION_NATIVE) {
precision = std::is_same_v<decltype(threshold_t(0) + leaf_t(0)), float> ? PRECISION_FLOAT32
: PRECISION_FLOAT64;
}

// get the pointer to the right forest variant
*pforest_variant = (forest_t<real_t>)nullptr;
forest_t<real_t>* pforest = &std::get<forest_t<real_t>>(*pforest_variant);
switch (precision) {
case PRECISION_FLOAT32: {
*pforest_variant = (forest_t<float>)nullptr;
forest_t<float>* pforest = &std::get<forest_t<float>>(*pforest_variant);
from_treelite(handle, pforest, model, tl_params);
break;
}
case PRECISION_FLOAT64: {
*pforest_variant = (forest_t<double>)nullptr;
forest_t<double>* pforest = &std::get<forest_t<double>>(*pforest_variant);
from_treelite(handle, pforest, model, tl_params);
break;
}
default:
ASSERT(false,
"bad value of tl_params->precision, must be one of "
"PRECISION_{NATIVE,FLOAT32,FLOAT64}");
}
}

template <typename threshold_t, typename leaf_t, typename real_t>
void from_treelite(const raft::handle_t& handle,
forest_t<real_t>* pforest,
const tl::ModelImpl<threshold_t, leaf_t>& model,
const treelite_params_t* tl_params)
{
// Invariants on threshold and leaf types
static_assert(type_supported<threshold_t>(),
"Model must contain float32 or float64 thresholds for splits");
Expand Down
1 change: 1 addition & 0 deletions cpp/test/sg/fil_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,7 @@ class TreeliteFilTest : public BaseFilTest<real_t> {
params.threads_per_tree = this->ps.threads_per_tree;
params.n_items = this->ps.n_items;
params.pforest_shape_str = this->ps.print_forest_shape ? &forest_shape_str : nullptr;
params.precision = fil::PRECISION_NATIVE;
fil::forest_variant forest_variant;
fil::from_treelite(this->handle, &forest_variant, (ModelHandle)model.get(), &params);
*pforest = std::get<fil::forest_t<real_t>>(forest_variant);
Expand Down
54 changes: 53 additions & 1 deletion python/cuml/fil/fil.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,11 @@ cdef extern from "cuml/fil/fil.h" namespace "ML::fil":
SPARSE,
SPARSE8

cdef enum precision_t:
PRECISION_NATIVE,
PRECISION_FLOAT32,
PRECISION_FLOAT64

cdef cppclass forest[real_t]:
pass

Expand Down Expand Up @@ -227,6 +232,7 @@ cdef extern from "cuml/fil/fil.h" namespace "ML::fil":
int n_items
# this affects inference performance and will become configurable soon
char** pforest_shape_str
precision_t precision

cdef void free[real_t](handle_t& handle,
forest[real_t]*)
Expand Down Expand Up @@ -305,6 +311,18 @@ cdef class ForestInference_impl():
logger.info('storage_type=="sparse8" is an experimental feature')
return storage_type_dict[storage_type_str]

def get_precision(self, precision):
precision_dict = {'native': precision_t.PRECISION_NATIVE,
'float32': precision_t.PRECISION_FLOAT32,
'float64': precision_t.PRECISION_FLOAT64}
if precision not in precision_dict:
raise ValueError(
"The value entered for precision is not "
"supported. Please refer to the documentation at"
"(https://docs.rapids.ai/api/cuml/nightly/api.html#"
"forest-inferencing) to see the accepted values.")
return precision_dict[precision]

def predict(self, X,
output_dtype=None,
predict_proba=False,
Expand Down Expand Up @@ -432,6 +450,7 @@ cdef class ForestInference_impl():
treelite_params.pforest_shape_str = &self.shape_str
else:
treelite_params.pforest_shape_str = NULL
treelite_params.precision = self.get_precision(kwargs['precision'])

cdef handle_t* handle_ =\
<handle_t*><size_t>self.handle.getHandle()
Expand Down Expand Up @@ -591,6 +610,15 @@ class ForestInference(Base,
if True or equivalent, creates a ForestInference.shape_str
(writes a human-readable forest shape description as a
multiline ascii string)
precision : string (default='native')
precision of weights and thresholds of the FIL model loaded from
the treelite model.
- ``'native'``: load in float64 if the treelite model contains float64
weights or thresholds, otherwise load in float32
- ``'float32'``: always load in float32, may lead to loss of precision
if the treelite model contains float64 weights or thresholds
- ``'float64'``: always load in float64
""")
return func

Expand Down Expand Up @@ -680,7 +708,7 @@ class ForestInference(Base,
threads_per_tree=1,
n_items=0,
compute_shape_str=False,
):
precision='native'):
"""Creates a FIL model using the treelite model
passed to the function.
Expand Down Expand Up @@ -718,6 +746,7 @@ class ForestInference(Base,
threads_per_tree=1,
n_items=0,
compute_shape_str=False,
precision='native',
handle=None):
"""
Creates a FIL model using the scikit-learn model passed to the
Expand Down Expand Up @@ -769,6 +798,16 @@ class ForestInference(Base,
if True or equivalent, creates a ForestInference.shape_str
(writes a human-readable forest shape description as a
multiline ascii string)
precision : string (default='native')
precision of weights and thresholds of the FIL model loaded from
the treelite model.
- ``'native'``: load in float64 if the treelite model contains
float64 weights or thresholds, otherwise load in float32
- ``'float32'``: always load in float32, may lead to loss of
precision if the treelite model contains float64 weights or
thresholds
- ``'float64'``: always load in float64
Returns
----------
Expand Down Expand Up @@ -797,6 +836,7 @@ class ForestInference(Base,
threads_per_tree=1,
n_items=0,
compute_shape_str=False,
precision='native',
model_type="xgboost",
handle=None):
"""
Expand Down Expand Up @@ -851,6 +891,17 @@ class ForestInference(Base,
if True or equivalent, creates a ForestInference.shape_str
(writes a human-readable forest shape description as a
multiline ascii string)
precision : string (default='native')
precision of weights and thresholds of the FIL model loaded from
the treelite model.
- ``'native'``: load in float64 if the treelite model contains
float64 weights or thresholds, otherwise load in float32
- ``'float32'``: always load in float32, may lead to loss of
precision if the treelite model contains float64 weights or
thresholds
- ``'float64'``: always load in float64
model_type : string (default="xgboost")
Format of the saved treelite model to be load.
It can be 'xgboost', 'xgboost_json', 'lightgbm'.
Expand Down Expand Up @@ -880,6 +931,7 @@ class ForestInference(Base,
threads_per_tree=1,
n_items=0,
compute_shape_str=False,
precision='native'
):
"""
Returns a FIL instance by converting a treelite model to
Expand Down
55 changes: 39 additions & 16 deletions python/cuml/tests/test_fil.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
if has_xgboost():
import xgboost as xgb

pytestmark = pytest.mark.skip
# pytestmark = pytest.mark.skip


def simulate_data(m, n, k=2, n_informative='auto', random_state=None,
Expand Down Expand Up @@ -212,21 +212,26 @@ def test_fil_regression(n_rows, n_columns, num_rounds, tmp_path, max_depth):
[(2, False), (2, True), (10, False), (10, True),
(20, True)])
# When n_classes=25, fit a single estimator only to reduce test time
@pytest.mark.parametrize('n_classes,model_class,n_estimators',
[(2, GradientBoostingClassifier, 1),
(2, GradientBoostingClassifier, 10),
(2, RandomForestClassifier, 1),
(5, RandomForestClassifier, 1),
(2, RandomForestClassifier, 10),
(5, RandomForestClassifier, 10),
(2, ExtraTreesClassifier, 1),
(2, ExtraTreesClassifier, 10),
(5, GradientBoostingClassifier, 1),
(5, GradientBoostingClassifier, 10),
(25, GradientBoostingClassifier, 1),
(25, RandomForestClassifier, 1)])
@pytest.mark.parametrize('n_classes,model_class,n_estimators,precision',
[(2, GradientBoostingClassifier, 1, 'native'),
(2, GradientBoostingClassifier, 10, 'native'),
(2, RandomForestClassifier, 1, 'native'),
(5, RandomForestClassifier, 1, 'native'),
(2, RandomForestClassifier, 10, 'native'),
(5, RandomForestClassifier, 10, 'native'),
(2, ExtraTreesClassifier, 1, 'native'),
(2, ExtraTreesClassifier, 10, 'native'),
(5, GradientBoostingClassifier, 1, 'native'),
(5, GradientBoostingClassifier, 10, 'native'),
(25, GradientBoostingClassifier, 1, 'native'),
(25, RandomForestClassifier, 1, 'native'),
(2, RandomForestClassifier, 10, 'float32'),
(2, RandomForestClassifier, 10, 'float64'),
(5, RandomForestClassifier, 10, 'float32'),
(5, RandomForestClassifier, 10, 'float64')])
def test_fil_skl_classification(n_rows, n_columns, n_estimators, max_depth,
n_classes, storage_type, model_class):
n_classes, storage_type, precision,
model_class):
# settings
classification = True # change this to false to use regression
random_state = np.random.RandomState(43210)
Expand Down Expand Up @@ -266,7 +271,8 @@ def test_fil_skl_classification(n_rows, n_columns, n_estimators, max_depth,
algo=algo,
output_class=True,
threshold=0.50,
storage_type=storage_type)
storage_type=storage_type,
precision=precision)
fil_preds = np.asarray(fm.predict(X_validation))
fil_preds = np.reshape(fil_preds, np.shape(skl_preds_int))
fil_acc = accuracy_score(y_validation, fil_preds)
Expand Down Expand Up @@ -389,6 +395,23 @@ def test_output_algos(algo, small_classifier_and_preds):
assert np.allclose(fil_preds, xgb_preds_int, 1e-3)


@pytest.mark.skipif(has_xgboost() is False, reason="need to install xgboost")
@pytest.mark.parametrize('precision', ['native', 'float32', 'float64'])
def test_precision_xgboost(precision, small_classifier_and_preds):
model_path, model_type, X, xgb_preds = small_classifier_and_preds
fm = ForestInference.load(model_path,
model_type=model_type,
output_class=True,
threshold=0.50,
precision=precision)

xgb_preds_int = np.around(xgb_preds)
fil_preds = np.asarray(fm.predict(X))
fil_preds = np.reshape(fil_preds, np.shape(xgb_preds_int))

assert np.allclose(fil_preds, xgb_preds_int, 1e-3)


@pytest.mark.skipif(has_xgboost() is False, reason="need to install xgboost")
@pytest.mark.parametrize('storage_type',
[False, True, 'auto', 'dense', 'sparse', 'sparse8'])
Expand Down

0 comments on commit dadedda

Please sign in to comment.