Skip to content

Commit

Permalink
Merge pull request #5354 from rapidsai/branch-23.04
Browse files Browse the repository at this point in the history
Forward-merge branch-23.04 to branch-23.06
  • Loading branch information
GPUtester authored Apr 11, 2023
2 parents fb00387 + 6faf82e commit c277e1c
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions python/cuml/fil/fil.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,19 @@ cimport cuml.common.cuda

cdef extern from "treelite/c_api.h":
ctypedef void* ModelHandle
cdef int TreeliteLoadXGBoostModel(const char* filename,
ModelHandle* out) except +
cdef int TreeliteLoadXGBoostJSON(const char* filename,
ModelHandle* out) except +
cdef int TreeliteLoadXGBoostModelEx(const char* filename,
const char* config_json,
ModelHandle* out) except +
cdef int TreeliteLoadXGBoostJSONEx(const char* filename,
const char* config_json,
ModelHandle* out) except +
cdef int TreeliteFreeModel(ModelHandle handle) except +
cdef int TreeliteQueryNumTree(ModelHandle handle, size_t* out) except +
cdef int TreeliteQueryNumFeature(ModelHandle handle, size_t* out) except +
cdef int TreeliteQueryNumClass(ModelHandle handle, size_t* out) except +
cdef int TreeliteLoadLightGBMModel(const char* filename,
ModelHandle* out) except +
cdef int TreeliteLoadLightGBMModelEx(const char* filename,
const char* config_json,
ModelHandle* out) except +
cdef int TreeliteSerializeModel(const char* filename,
ModelHandle handle) except +
cdef int TreeliteDeserializeModel(const char* filename,
Expand Down Expand Up @@ -137,22 +140,23 @@ cdef class TreeliteModel():
Type of model: 'xgboost', 'xgboost_json', or 'lightgbm'
"""
filename_bytes = filename.encode("UTF-8")
config_bytes = "{}".encode("UTF-8")
cdef ModelHandle handle
if model_type == "xgboost":
res = TreeliteLoadXGBoostModel(filename_bytes, &handle)
res = TreeliteLoadXGBoostModelEx(filename_bytes, config_bytes, &handle)
if res < 0:
err = TreeliteGetLastError()
raise RuntimeError("Failed to load %s (%s)" % (filename, err))
elif model_type == "xgboost_json":
res = TreeliteLoadXGBoostJSON(filename_bytes, &handle)
res = TreeliteLoadXGBoostJSONEx(filename_bytes, config_bytes, &handle)
if res < 0:
err = TreeliteGetLastError()
raise RuntimeError("Failed to load %s (%s)" % (filename, err))
elif model_type == "lightgbm":
logger.warn("Treelite currently does not support float64 model"
" parameters. Accuracy may degrade slightly relative"
" to native LightGBM invocation.")
res = TreeliteLoadLightGBMModel(filename_bytes, &handle)
res = TreeliteLoadLightGBMModelEx(filename_bytes, config_bytes, &handle)
if res < 0:
err = TreeliteGetLastError()
raise RuntimeError("Failed to load %s (%s)" % (filename, err))
Expand Down

0 comments on commit c277e1c

Please sign in to comment.