From da8529ff8eee92d41497cde017a879c318f8faa4 Mon Sep 17 00:00:00 2001 From: Levs Dolgovs Date: Fri, 13 Nov 2020 18:32:54 -0800 Subject: [PATCH] simplify treelite.Model handle extraction --- python/cuml/fil/fil.pyx | 36 +++--------------------------------- 1 file changed, 3 insertions(+), 33 deletions(-) diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx index c8c9b8e994..8df3ed70d5 100644 --- a/python/cuml/fil/fil.pyx +++ b/python/cuml/fil/fil.pyx @@ -310,44 +310,14 @@ cdef class ForestInference_impl(): output_dtype=output_dtype ) - def load_from_treelite_model_handle(self, - uintptr_t model_handle, - bool output_class, - str algo, - float threshold, - str storage_type): - cdef treelite_params_t treelite_params - - self.output_class = output_class - treelite_params.output_class = self.output_class - treelite_params.threshold = threshold - treelite_params.algo = self.get_algo(algo) - treelite_params.storage_type = self.get_storage_type(storage_type) - - self.forest_data = NULL - cdef handle_t* handle_ =\ - self.handle.getHandle() - cdef uintptr_t model_ptr = model_handle - - from_treelite(handle_[0], - &self.forest_data, - model_ptr, - &treelite_params) - TreeliteQueryNumOutputGroups( model_ptr, - & self.num_output_groups) - return self - def load_from_treelite_model(self, TreeliteModel model, bool output_class, str algo, float threshold, str storage_type): - TreeliteQueryNumOutputGroups( model.handle, - & self.num_output_groups) - return self.load_from_treelite_model_handle(model.handle, - output_class, algo, - threshold, storage_type) + return self.load_using_treelite_handle(model.handle, output_class, + algo, threshold, storage_type) def load_using_treelite_handle(self, model_handle, @@ -582,7 +552,7 @@ class ForestInference(Base): model, output_class, algo, threshold, str(storage_type)) else: # assume it is treelite.Model - return self._impl.load_from_treelite_model_handle( + return self._impl.load_using_treelite_handle( model.handle.value, output_class, algo, threshold, str(storage_type))