diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx index eb9b26a5cc..b8b6f72913 100644 --- a/python/cuml/fil/fil.pyx +++ b/python/cuml/fil/fil.pyx @@ -50,16 +50,19 @@ cimport cuml.common.cuda cdef extern from "treelite/c_api.h": ctypedef void* ModelHandle - cdef int TreeliteLoadXGBoostModel(const char* filename, - ModelHandle* out) except + - cdef int TreeliteLoadXGBoostJSON(const char* filename, - ModelHandle* out) except + + cdef int TreeliteLoadXGBoostModelEx(const char* filename, + const char* config_json, + ModelHandle* out) except + + cdef int TreeliteLoadXGBoostJSONEx(const char* filename, + const char* config_json, + ModelHandle* out) except + cdef int TreeliteFreeModel(ModelHandle handle) except + cdef int TreeliteQueryNumTree(ModelHandle handle, size_t* out) except + cdef int TreeliteQueryNumFeature(ModelHandle handle, size_t* out) except + cdef int TreeliteQueryNumClass(ModelHandle handle, size_t* out) except + - cdef int TreeliteLoadLightGBMModel(const char* filename, - ModelHandle* out) except + + cdef int TreeliteLoadLightGBMModelEx(const char* filename, + const char* config_json, + ModelHandle* out) except + cdef int TreeliteSerializeModel(const char* filename, ModelHandle handle) except + cdef int TreeliteDeserializeModel(const char* filename, @@ -137,14 +140,15 @@ cdef class TreeliteModel(): Type of model: 'xgboost', 'xgboost_json', or 'lightgbm' """ filename_bytes = filename.encode("UTF-8") + config_bytes = "{}".encode("UTF-8") cdef ModelHandle handle if model_type == "xgboost": - res = TreeliteLoadXGBoostModel(filename_bytes, &handle) + res = TreeliteLoadXGBoostModelEx(filename_bytes, config_bytes, &handle) if res < 0: err = TreeliteGetLastError() raise RuntimeError("Failed to load %s (%s)" % (filename, err)) elif model_type == "xgboost_json": - res = TreeliteLoadXGBoostJSON(filename_bytes, &handle) + res = TreeliteLoadXGBoostJSONEx(filename_bytes, config_bytes, &handle) if res < 0: err = TreeliteGetLastError() raise RuntimeError("Failed to load %s (%s)" % (filename, err)) @@ -152,7 +156,7 @@ cdef class TreeliteModel(): logger.warn("Treelite currently does not support float64 model" " parameters. Accuracy may degrade slightly relative" " to native LightGBM invocation.") - res = TreeliteLoadLightGBMModel(filename_bytes, &handle) + res = TreeliteLoadLightGBMModelEx(filename_bytes, config_bytes, &handle) if res < 0: err = TreeliteGetLastError() raise RuntimeError("Failed to load %s (%s)" % (filename, err))