Skip to content

Commit

Permalink
Update doc to indicate ExtraTree support
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 committed Mar 18, 2021
1 parent 14bd6c1 commit 0de43fe
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
7 changes: 4 additions & 3 deletions python/cuml/fil/fil.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,9 @@ class ForestInference(Base,
* A single row of data should fit into the shared memory of a thread
block, which means that more than 12288 features are not supported.
* From sklearn.ensemble, only
{RandomForest,GradientBoosting}{Classifier,Regressor} models are
supported. Other sklearn.ensemble models are currently not supported.
{RandomForest,GradientBoosting,ExtraTrees}{Classifier,Regressor} models
are supported. Other sklearn.ensemble models are currently not
supported.
* Importing large SKLearn models can be slow, as it is done in Python.
* LightGBM categorical features are not supported.
* Inference uses a dense matrix format, which is efficient for many
Expand Down Expand Up @@ -619,7 +620,7 @@ class ForestInference(Base,
handle=None):
"""
Creates a FIL model using the scikit-learn model passed to the
function. This function requires Treelite 0.90 to be installed.
function. This function requires Treelite 1.0.0+ to be installed.
Parameters
----------
Expand Down
11 changes: 8 additions & 3 deletions python/cuml/test/test_fil.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@

from sklearn.datasets import make_classification, make_regression
from sklearn.ensemble import GradientBoostingClassifier, \
GradientBoostingRegressor, RandomForestClassifier, RandomForestRegressor
GradientBoostingRegressor, RandomForestClassifier, RandomForestRegressor, \
ExtraTreesClassifier, ExtraTreesRegressor
from sklearn.metrics import accuracy_score, mean_squared_error
from sklearn.model_selection import train_test_split

Expand Down Expand Up @@ -213,6 +214,8 @@ def test_fil_regression(n_rows, n_columns, num_rounds, tmp_path, max_depth):
(2, GradientBoostingClassifier, 10),
(2, RandomForestClassifier, 1),
(2, RandomForestClassifier, 10),
(2, ExtraTreesClassifier, 1),
(2, ExtraTreesClassifier, 10),
(5, GradientBoostingClassifier, 1),
(5, GradientBoostingClassifier, 10),
(25, GradientBoostingClassifier, 1)])
Expand All @@ -235,7 +238,7 @@ def test_fil_skl_classification(n_rows, n_columns, n_estimators, max_depth,
'n_estimators': n_estimators,
'max_depth': max_depth,
}
if model_class == RandomForestClassifier:
if model_class in [RandomForestClassifier, ExtraTreesClassifier]:
init_kwargs['max_features'] = 0.3
init_kwargs['n_jobs'] = -1
else:
Expand Down Expand Up @@ -283,6 +286,8 @@ def test_fil_skl_classification(n_rows, n_columns, n_estimators, max_depth,
(1, GradientBoostingRegressor, 10),
(1, RandomForestRegressor, 1),
(1, RandomForestRegressor, 10),
(1, ExtraTreesRegressor, 1),
(1, ExtraTreesRegressor, 10),
(5, GradientBoostingRegressor, 10)])
@pytest.mark.parametrize('max_depth', [2, 10, 20])
@pytest.mark.parametrize('storage_type', [False, True])
Expand All @@ -309,7 +314,7 @@ def test_fil_skl_regression(n_rows, n_columns, n_classes, model_class,
'n_estimators': n_estimators,
'max_depth': max_depth,
}
if model_class == RandomForestRegressor:
if model_class in [RandomForestRegressor, ExtraTreesRegressor]:
init_kwargs['max_features'] = 0.3
init_kwargs['n_jobs'] = -1
else:
Expand Down

0 comments on commit 0de43fe

Please sign in to comment.