diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx index 37d96f8b8a..56e290ca60 100644 --- a/python/cuml/fil/fil.pyx +++ b/python/cuml/fil/fil.pyx @@ -116,13 +116,13 @@ cdef class TreeliteModel(): TreeliteQueryNumFeature(self.handle, &out) return out - @staticmethod - def free_treelite_model(model_handle): + @classmethod + def free_treelite_model(cls, model_handle): cdef uintptr_t model_ptr = model_handle TreeliteFreeModel( model_ptr) - @staticmethod - def from_filename(filename, model_type="xgboost"): + @classmethod + def from_filename(cls, filename, model_type="xgboost"): """ Returns a TreeliteModel object loaded from `filename` @@ -173,8 +173,9 @@ cdef class TreeliteModel(): filename_bytes = filename.encode("UTF-8") TreeliteSerializeModel(filename_bytes, self.handle) - @staticmethod - def from_treelite_model_handle(treelite_handle, + @classmethod + def from_treelite_model_handle(cls, + treelite_handle, take_handle_ownership=False): cdef ModelHandle handle = treelite_handle model = TreeliteModel(owns_handle=take_handle_ownership) @@ -843,8 +844,9 @@ class ForestInference(Base, ) return cuml_fm - @staticmethod - def load(filename, + @classmethod + def load(cls, + filename, output_class=False, threshold=0.50, algo='auto',