diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx index 26ddaaa08b..92c3884cd7 100644 --- a/python/cuml/fil/fil.pyx +++ b/python/cuml/fil/fil.pyx @@ -120,7 +120,7 @@ cdef class TreeliteModel(): Path to treelite model file to load model_type : string - Type of model: 'xgboost', or 'lightgbm' + Type of model: 'xgboost', 'xgboost_json', or 'lightgbm' """ filename_bytes = filename.encode("UTF-8") cdef ModelHandle handle @@ -728,7 +728,7 @@ class ForestInference(Base, model_type : string (default="xgboost") Format of the saved treelite model to be load. - It can be 'xgboost', 'lightgbm'. + It can be 'xgboost', 'xgboost_json', 'lightgbm'. Returns ---------- diff --git a/python/cuml/test/test_fil.py b/python/cuml/test/test_fil.py index afbe440870..52a6850d11 100644 --- a/python/cuml/test/test_fil.py +++ b/python/cuml/test/test_fil.py @@ -343,19 +343,22 @@ def test_fil_skl_regression(n_rows, n_columns, n_classes, model_class, assert np.allclose(fil_preds, skl_preds, 1.2e-3) -@pytest.fixture(scope="session") -def small_classifier_and_preds(tmpdir_factory): +@pytest.fixture(scope="session", params=['binary', 'json']) +def small_classifier_and_preds(tmpdir_factory, request): X, y = simulate_data(500, 10, random_state=43210, classification=True) - model_path = str(tmpdir_factory.mktemp("models").join("small_class.model")) + ext = 'json' if request.param == 'json' else 'model' + model_type = 'xgboost_json' if request.param == 'json' else 'xgboost' + model_path = str(tmpdir_factory.mktemp("models").join( + f"small_class.{ext}")) bst = _build_and_save_xgboost(model_path, X, y) # just do within-sample since it's not an accuracy test dtrain = xgb.DMatrix(X, label=y) xgb_preds = bst.predict(dtrain) - return (model_path, X, xgb_preds) + return (model_path, model_type, X, xgb_preds) @pytest.mark.skipif(has_xgboost() is False, reason="need to install xgboost") @@ -364,8 +367,9 @@ def small_classifier_and_preds(tmpdir_factory): 'auto', 'naive', 'tree_reorg', 'batch_tree_reorg']) def test_output_algos(algo, small_classifier_and_preds): - model_path, X, xgb_preds = small_classifier_and_preds + model_path, model_type, X, xgb_preds = small_classifier_and_preds fm = ForestInference.load(model_path, + model_type=model_type, algo=algo, output_class=True, threshold=0.50) @@ -381,8 +385,9 @@ def test_output_algos(algo, small_classifier_and_preds): @pytest.mark.parametrize('storage_type', [False, True, 'auto', 'dense', 'sparse', 'sparse8']) def test_output_storage_type(storage_type, small_classifier_and_preds): - model_path, X, xgb_preds = small_classifier_and_preds + model_path, model_type, X, xgb_preds = small_classifier_and_preds fm = ForestInference.load(model_path, + model_type=model_type, output_class=True, storage_type=storage_type, threshold=0.50) @@ -399,8 +404,9 @@ def test_output_storage_type(storage_type, small_classifier_and_preds): @pytest.mark.parametrize('blocks_per_sm', [1, 2, 3, 4]) def test_output_blocks_per_sm(storage_type, blocks_per_sm, small_classifier_and_preds): - model_path, X, xgb_preds = small_classifier_and_preds + model_path, model_type, X, xgb_preds = small_classifier_and_preds fm = ForestInference.load(model_path, + model_type=model_type, output_class=True, storage_type=storage_type, threshold=0.50, @@ -416,8 +422,9 @@ def test_output_blocks_per_sm(storage_type, blocks_per_sm, @pytest.mark.parametrize('output_class', [True, False]) @pytest.mark.skipif(has_xgboost() is False, reason="need to install xgboost") def test_thresholding(output_class, small_classifier_and_preds): - model_path, X, xgb_preds = small_classifier_and_preds + model_path, model_type, X, xgb_preds = small_classifier_and_preds fm = ForestInference.load(model_path, + model_type=model_type, algo='TREE_REORG', output_class=output_class, threshold=0.50) @@ -430,8 +437,9 @@ def test_thresholding(output_class, small_classifier_and_preds): @pytest.mark.skipif(has_xgboost() is False, reason="need to install xgboost") def test_output_args(small_classifier_and_preds): - model_path, X, xgb_preds = small_classifier_and_preds + model_path, model_type, X, xgb_preds = small_classifier_and_preds fm = ForestInference.load(model_path, + model_type=model_type, algo='TREE_REORG', output_class=False, threshold=0.50)