From efb87b4e9fbb63da27ca460b99e569563bafc8be Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Tue, 6 Aug 2024 07:54:41 -0700 Subject: [PATCH] Add support for XGBoost UBJSON in FIL (#6009) Closes #6008 --------- Co-authored-by: Dante Gama Dessavre --- dependencies.yaml | 1 + notebooks/forest_inference_demo.ipynb | 4 ++-- python/cuml/cuml/experimental/fil/fil.pyx | 18 ++++++++++---- python/cuml/cuml/fil/fil.pyx | 29 ++++++++++++++++------- 4 files changed, 37 insertions(+), 15 deletions(-) diff --git a/dependencies.yaml b/dependencies.yaml index 4c4232cd5a..14c6c756ba 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -503,3 +503,4 @@ dependencies: - pandas - *scikit_learn - seaborn + - xgboost diff --git a/notebooks/forest_inference_demo.ipynb b/notebooks/forest_inference_demo.ipynb index 3e57c2165e..98b4648323 100644 --- a/notebooks/forest_inference_demo.ipynb +++ b/notebooks/forest_inference_demo.ipynb @@ -275,7 +275,7 @@ " algo='BATCH_TREE_REORG',\n", " output_class=True,\n", " threshold=0.50,\n", - " model_type='xgboost'\n", + " model_type='xgboost_ubj'\n", ")" ] }, @@ -507,7 +507,7 @@ " algo='BATCH_TREE_REORG',\n", " output_class=True,\n", " threshold=0.50,\n", - " model_type='xgboost'\n", + " model_type='xgboost_ubj'\n", " )" ] }, diff --git a/python/cuml/cuml/experimental/fil/fil.pyx b/python/cuml/cuml/experimental/fil/fil.pyx index 7fe59e43a1..247e968e1b 100644 --- a/python/cuml/cuml/experimental/fil/fil.pyx +++ b/python/cuml/cuml/experimental/fil/fil.pyx @@ -803,7 +803,7 @@ class ForestInference(UniversalBase, CMajorInputTagMixin): only for models trained and double precision and when exact conformance between results from FIL and the original training framework is of paramount importance. - model_type : {'xgboost', 'xgboost_json', 'lightgbm', + model_type : {'xgboost_ubj', 'xgboost_json', 'xgboost', 'lightgbm', 'treelite_checkpoint', None }, default=None The serialization format for the model file. If None, a best-effort guess will be made based on the file extension. @@ -841,18 +841,26 @@ class ForestInference(UniversalBase, CMajorInputTagMixin): extension = pathlib.Path(path).suffix if extension == '.json': model_type = 'xgboost_json' + elif extension == '.ubj': + model_type = 'xgboost_ubj' elif extension == '.model': model_type = 'xgboost' elif extension == '.txt': model_type = 'lightgbm' else: model_type = 'treelite_checkpoint' - if model_type == 'treelite_checkpoint': + if model_type == "treelite_checkpoint": tl_model = treelite.frontend.Model.deserialize(path) + elif model_type == "xgboost_ubj": + tl_model = treelite.frontend.load_xgboost_model(path, format_choice="ubjson") + elif model_type == "xgboost_json": + tl_model = treelite.frontend.load_xgboost_model(path, format_choice="json") + elif model_type == "xgboost": + tl_model = treelite.frontend.load_xgboost_model_legacy_binary(path) + elif model_type == "lightgbm": + tl_model = treelite.frontend.load_lightgbm_model(path) else: - tl_model = treelite.frontend.Model.load( - path, model_type - ) + raise ValueError(f"Unknown model type: {model_type}") if default_chunk_size is None: default_chunk_size = threads_per_tree return cls( diff --git a/python/cuml/cuml/fil/fil.pyx b/python/cuml/cuml/fil/fil.pyx index 44c656bef4..2589ec8bd8 100644 --- a/python/cuml/cuml/fil/fil.pyx +++ b/python/cuml/cuml/fil/fil.pyx @@ -51,6 +51,9 @@ cdef extern from "treelite/c_api.h": size_t nitem ctypedef void* TreeliteModelHandle ctypedef void* TreeliteGTILConfigHandle + cdef int TreeliteLoadXGBoostModelUBJSON(const char* filename, + const char* config_json, + TreeliteModelHandle* out) except + cdef int TreeliteLoadXGBoostModelLegacyBinary(const char* filename, const char* config_json, TreeliteModelHandle* out) except + @@ -188,7 +191,7 @@ cdef class TreeliteModel(): return model @classmethod - def from_filename(cls, filename, model_type="xgboost"): + def from_filename(cls, filename, model_type="xgboost_ubj"): """ Returns a TreeliteModel object loaded from `filename` @@ -198,15 +201,15 @@ cdef class TreeliteModel(): Path to treelite model file to load model_type : string - Type of model: 'xgboost', 'xgboost_json', or 'lightgbm' + Type of model: 'xgboost_ubj', 'xgboost_json', 'xgboost' or 'lightgbm' """ cdef bytes filename_bytes = filename.encode("UTF-8") cdef bytes config_bytes = b"{}" cdef TreeliteModelHandle handle cdef int res cdef str err_msg - if model_type == "xgboost": - res = TreeliteLoadXGBoostModelLegacyBinary(filename_bytes, config_bytes, &handle) + if model_type == "xgboost_ubj": + res = TreeliteLoadXGBoostModelUBJSON(filename_bytes, config_bytes, &handle) if res < 0: err_msg = TreeliteGetLastError().decode("UTF-8") raise RuntimeError(f"Failed to load {filename} ({err_msg})") @@ -215,6 +218,11 @@ cdef class TreeliteModel(): if res < 0: err_msg = TreeliteGetLastError().decode("UTF-8") raise RuntimeError(f"Failed to load {filename} ({err_msg})") + elif model_type == "xgboost": + res = TreeliteLoadXGBoostModelLegacyBinary(filename_bytes, config_bytes, &handle) + if res < 0: + err_msg = TreeliteGetLastError().decode("UTF-8") + raise RuntimeError(f"Failed to load {filename} ({err_msg})") elif model_type == "lightgbm": logger.warn("Treelite currently does not support float64 model" " parameters. Accuracy may degrade slightly relative" @@ -953,7 +961,7 @@ class ForestInference(Base, n_items=0, compute_shape_str=False, precision='native', - model_type="xgboost", + model_type="xgboost_ubj", handle=None): """ Returns a FIL instance containing the forest saved in `filename` @@ -1018,9 +1026,14 @@ class ForestInference(Base, thresholds - ``'float64'``: always load in float64 - model_type : string (default="xgboost") - Format of the saved treelite model to be load. - It can be 'xgboost', 'xgboost_json', 'lightgbm'. + model_type : string (default="xgboost_ubj") + Format of the saved tree model to be load. + It can be one of the following: + + - ``'xgboost_ubj'``: XGBoost model, using the UBJSON format (default in XGBoost 2.1+) + - ``'xgboost_json'``: XGBoost model, using the JSON format + - ``'xgboost'``: XGBoost model, using the legacy binary format + - ``'lightgbm'``: LightGBM model Returns -------