Document model_type='xgboost_json' in FIL (#3633)
Closes #3625. Treelite already supports the XGBoost JSON format, so we just need to expose that capability through the FIL loading function as well.

Also add `model_type='xgboost_json'` to the FIL testing matrix.
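
For context, a minimal sketch of the load call this commit documents, assuming an XGBoost model has already been saved in the JSON format (the path and keyword values below are illustrative, not taken from this change):

```python
from cuml import ForestInference

# Load a forest that was serialized with XGBoost's JSON format.
# "model.json" is a placeholder path; output_class/threshold are example settings.
fm = ForestInference.load("model.json",
                          model_type="xgboost_json",
                          output_class=True,
                          threshold=0.50)
# fm.predict(X) then runs GPU-accelerated inference on a float32 array X.
```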

Authors:
  - Philip Hyunsu Cho (@hcho3)

Approvers:
  - John Zedlewski (@JohnZed)

URL: #3633
hcho3 authored Mar 25, 2021
1 parent 9a2bd0c commit 5a82fdb
Showing 2 changed files with 19 additions and 11 deletions.
4 changes: 2 additions & 2 deletions python/cuml/fil/fil.pyx
@@ -120,7 +120,7 @@ cdef class TreeliteModel():
             Path to treelite model file to load
         model_type : string
-            Type of model: 'xgboost', or 'lightgbm'
+            Type of model: 'xgboost', 'xgboost_json', or 'lightgbm'
         """
         filename_bytes = filename.encode("UTF-8")
         cdef ModelHandle handle
@@ -728,7 +728,7 @@ class ForestInference(Base,
         model_type : string (default="xgboost")
             Format of the saved treelite model to be load.
-            It can be 'xgboost', 'lightgbm'.
+            It can be 'xgboost', 'xgboost_json', 'lightgbm'.
         Returns
         ----------
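
On the producer side, XGBoost selects the serialization format from the file extension given to `Booster.save_model`, which is where the two XGBoost-related `model_type` values above come from. A small sketch under that assumption (data and file names are illustrative):

```python
import numpy as np
import xgboost as xgb

# Train a tiny throwaway booster so there is something to serialize.
X = np.random.rand(100, 10).astype(np.float32)
y = (np.random.rand(100) > 0.5).astype(np.float32)
bst = xgb.train({"objective": "binary:logistic"},
                xgb.DMatrix(X, label=y),
                num_boost_round=5)

bst.save_model("small_class.model")  # binary format -> FIL model_type='xgboost'
bst.save_model("small_class.json")   # JSON format   -> FIL model_type='xgboost_json'
```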
26 changes: 17 additions & 9 deletions python/cuml/test/test_fil.py
@@ -343,19 +343,22 @@ def test_fil_skl_regression(n_rows, n_columns, n_classes, model_class,
     assert np.allclose(fil_preds, skl_preds, 1.2e-3)


-@pytest.fixture(scope="session")
-def small_classifier_and_preds(tmpdir_factory):
+@pytest.fixture(scope="session", params=['binary', 'json'])
+def small_classifier_and_preds(tmpdir_factory, request):
     X, y = simulate_data(500, 10,
                          random_state=43210,
                          classification=True)

-    model_path = str(tmpdir_factory.mktemp("models").join("small_class.model"))
+    ext = 'json' if request.param == 'json' else 'model'
+    model_type = 'xgboost_json' if request.param == 'json' else 'xgboost'
+    model_path = str(tmpdir_factory.mktemp("models").join(
+        f"small_class.{ext}"))
     bst = _build_and_save_xgboost(model_path, X, y)
     # just do within-sample since it's not an accuracy test
     dtrain = xgb.DMatrix(X, label=y)
     xgb_preds = bst.predict(dtrain)

-    return (model_path, X, xgb_preds)
+    return (model_path, model_type, X, xgb_preds)


 @pytest.mark.skipif(has_xgboost() is False, reason="need to install xgboost")
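
For readers less familiar with the pytest idiom above: a session-scoped fixture with `params` is built once per parameter value, and every test that requests it is collected once per value, which is what widens the FIL test matrix here. A standalone sketch (names are illustrative, not from test_fil.py):

```python
import pytest


@pytest.fixture(scope="session", params=["binary", "json"])
def model_format(request):
    # Built once per parameter value and cached for the whole session.
    return request.param


def test_runs_per_format(model_format):
    # Collected as test_runs_per_format[binary] and test_runs_per_format[json].
    assert model_format in ("binary", "json")
```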
@@ -364,8 +367,9 @@ def small_classifier_and_preds(tmpdir_factory):
                                       'auto', 'naive', 'tree_reorg',
                                       'batch_tree_reorg'])
 def test_output_algos(algo, small_classifier_and_preds):
-    model_path, X, xgb_preds = small_classifier_and_preds
+    model_path, model_type, X, xgb_preds = small_classifier_and_preds
     fm = ForestInference.load(model_path,
+                              model_type=model_type,
                               algo=algo,
                               output_class=True,
                               threshold=0.50)
@@ -381,8 +385,9 @@ def test_output_algos(algo, small_classifier_and_preds):
 @pytest.mark.parametrize('storage_type',
                          [False, True, 'auto', 'dense', 'sparse', 'sparse8'])
 def test_output_storage_type(storage_type, small_classifier_and_preds):
-    model_path, X, xgb_preds = small_classifier_and_preds
+    model_path, model_type, X, xgb_preds = small_classifier_and_preds
     fm = ForestInference.load(model_path,
+                              model_type=model_type,
                               output_class=True,
                               storage_type=storage_type,
                               threshold=0.50)
@@ -399,8 +404,9 @@ def test_output_storage_type(storage_type, small_classifier_and_preds):
 @pytest.mark.parametrize('blocks_per_sm', [1, 2, 3, 4])
 def test_output_blocks_per_sm(storage_type, blocks_per_sm,
                               small_classifier_and_preds):
-    model_path, X, xgb_preds = small_classifier_and_preds
+    model_path, model_type, X, xgb_preds = small_classifier_and_preds
     fm = ForestInference.load(model_path,
+                              model_type=model_type,
                               output_class=True,
                               storage_type=storage_type,
                               threshold=0.50,
@@ -416,8 +422,9 @@ def test_output_blocks_per_sm(storage_type, blocks_per_sm,
 @pytest.mark.parametrize('output_class', [True, False])
 @pytest.mark.skipif(has_xgboost() is False, reason="need to install xgboost")
 def test_thresholding(output_class, small_classifier_and_preds):
-    model_path, X, xgb_preds = small_classifier_and_preds
+    model_path, model_type, X, xgb_preds = small_classifier_and_preds
     fm = ForestInference.load(model_path,
+                              model_type=model_type,
                               algo='TREE_REORG',
                               output_class=output_class,
                               threshold=0.50)
@@ -430,8 +437,9 @@ def test_thresholding(output_class, small_classifier_and_preds):

 @pytest.mark.skipif(has_xgboost() is False, reason="need to install xgboost")
 def test_output_args(small_classifier_and_preds):
-    model_path, X, xgb_preds = small_classifier_and_preds
+    model_path, model_type, X, xgb_preds = small_classifier_and_preds
     fm = ForestInference.load(model_path,
+                              model_type=model_type,
                               algo='TREE_REORG',
                               output_class=False,
                               threshold=0.50)
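
As a quick check of the widened matrix, running a single consumer of the fixture, e.g. `pytest python/cuml/test/test_fil.py -k test_output_args -v`, should now report one parametrized ID per serialization format, along the lines of `test_output_args[binary]` and `test_output_args[json]`.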
