Skip to content

Commit

Permalink
Add support for XGBoost UBJSON in FIL (#6009)
Browse files Browse the repository at this point in the history
Closes #6008

---------

Co-authored-by: Dante Gama Dessavre <[email protected]>
  • Loading branch information
hcho3 and dantegd authored Aug 6, 2024
1 parent 50d1a74 commit efb87b4
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 15 deletions.
1 change: 1 addition & 0 deletions dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -503,3 +503,4 @@ dependencies:
- pandas
- *scikit_learn
- seaborn
- xgboost
4 changes: 2 additions & 2 deletions notebooks/forest_inference_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@
" algo='BATCH_TREE_REORG',\n",
" output_class=True,\n",
" threshold=0.50,\n",
" model_type='xgboost'\n",
" model_type='xgboost_ubj'\n",
")"
]
},
Expand Down Expand Up @@ -507,7 +507,7 @@
" algo='BATCH_TREE_REORG',\n",
" output_class=True,\n",
" threshold=0.50,\n",
" model_type='xgboost'\n",
" model_type='xgboost_ubj'\n",
" )"
]
},
Expand Down
18 changes: 13 additions & 5 deletions python/cuml/cuml/experimental/fil/fil.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ class ForestInference(UniversalBase, CMajorInputTagMixin):
only for models trained in double precision and when exact
conformance between results from FIL and the original training
framework is of paramount importance.
model_type : {'xgboost', 'xgboost_json', 'lightgbm',
model_type : {'xgboost_ubj', 'xgboost_json', 'xgboost', 'lightgbm',
'treelite_checkpoint', None }, default=None
The serialization format for the model file. If None, a best-effort
guess will be made based on the file extension.
Expand Down Expand Up @@ -841,18 +841,26 @@ class ForestInference(UniversalBase, CMajorInputTagMixin):
extension = pathlib.Path(path).suffix
if extension == '.json':
model_type = 'xgboost_json'
elif extension == '.ubj':
model_type = 'xgboost_ubj'
elif extension == '.model':
model_type = 'xgboost'
elif extension == '.txt':
model_type = 'lightgbm'
else:
model_type = 'treelite_checkpoint'
if model_type == 'treelite_checkpoint':
if model_type == "treelite_checkpoint":
tl_model = treelite.frontend.Model.deserialize(path)
elif model_type == "xgboost_ubj":
tl_model = treelite.frontend.load_xgboost_model(path, format_choice="ubjson")
elif model_type == "xgboost_json":
tl_model = treelite.frontend.load_xgboost_model(path, format_choice="json")
elif model_type == "xgboost":
tl_model = treelite.frontend.load_xgboost_model_legacy_binary(path)
elif model_type == "lightgbm":
tl_model = treelite.frontend.load_lightgbm_model(path)
else:
tl_model = treelite.frontend.Model.load(
path, model_type
)
raise ValueError(f"Unknown model type: {model_type}")
if default_chunk_size is None:
default_chunk_size = threads_per_tree
return cls(
Expand Down
29 changes: 21 additions & 8 deletions python/cuml/cuml/fil/fil.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ cdef extern from "treelite/c_api.h":
size_t nitem
ctypedef void* TreeliteModelHandle
ctypedef void* TreeliteGTILConfigHandle
cdef int TreeliteLoadXGBoostModelUBJSON(const char* filename,
const char* config_json,
TreeliteModelHandle* out) except +
cdef int TreeliteLoadXGBoostModelLegacyBinary(const char* filename,
const char* config_json,
TreeliteModelHandle* out) except +
Expand Down Expand Up @@ -188,7 +191,7 @@ cdef class TreeliteModel():
return model

@classmethod
def from_filename(cls, filename, model_type="xgboost"):
def from_filename(cls, filename, model_type="xgboost_ubj"):
"""
Returns a TreeliteModel object loaded from `filename`
Expand All @@ -198,15 +201,15 @@ cdef class TreeliteModel():
Path to treelite model file to load
model_type : string
Type of model: 'xgboost', 'xgboost_json', or 'lightgbm'
Type of model: 'xgboost_ubj', 'xgboost_json', 'xgboost' or 'lightgbm'
"""
cdef bytes filename_bytes = filename.encode("UTF-8")
cdef bytes config_bytes = b"{}"
cdef TreeliteModelHandle handle
cdef int res
cdef str err_msg
if model_type == "xgboost":
res = TreeliteLoadXGBoostModelLegacyBinary(filename_bytes, config_bytes, &handle)
if model_type == "xgboost_ubj":
res = TreeliteLoadXGBoostModelUBJSON(filename_bytes, config_bytes, &handle)
if res < 0:
err_msg = TreeliteGetLastError().decode("UTF-8")
raise RuntimeError(f"Failed to load {filename} ({err_msg})")
Expand All @@ -215,6 +218,11 @@ cdef class TreeliteModel():
if res < 0:
err_msg = TreeliteGetLastError().decode("UTF-8")
raise RuntimeError(f"Failed to load {filename} ({err_msg})")
elif model_type == "xgboost":
res = TreeliteLoadXGBoostModelLegacyBinary(filename_bytes, config_bytes, &handle)
if res < 0:
err_msg = TreeliteGetLastError().decode("UTF-8")
raise RuntimeError(f"Failed to load {filename} ({err_msg})")
elif model_type == "lightgbm":
logger.warn("Treelite currently does not support float64 model"
" parameters. Accuracy may degrade slightly relative"
Expand Down Expand Up @@ -953,7 +961,7 @@ class ForestInference(Base,
n_items=0,
compute_shape_str=False,
precision='native',
model_type="xgboost",
model_type="xgboost_ubj",
handle=None):
"""
Returns a FIL instance containing the forest saved in `filename`
Expand Down Expand Up @@ -1018,9 +1026,14 @@ class ForestInference(Base,
thresholds
- ``'float64'``: always load in float64
model_type : string (default="xgboost")
Format of the saved treelite model to be load.
It can be 'xgboost', 'xgboost_json', 'lightgbm'.
model_type : string (default="xgboost_ubj")
Format of the saved tree model to be loaded.
It can be one of the following:
- ``'xgboost_ubj'``: XGBoost model, using the UBJSON format (default in XGBoost 2.1+)
- ``'xgboost_json'``: XGBoost model, using the JSON format
- ``'xgboost'``: XGBoost model, using the legacy binary format
- ``'lightgbm'``: LightGBM model
Returns
-------
Expand Down

0 comments on commit efb87b4

Please sign in to comment.