diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index 9e2607071979..8e9a6fc2500c 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -555,23 +555,39 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalCounts(BoosterHandle handle, /*! * \brief Get names of evaluation datasets. * \param handle Handle of booster + * \param len Number of ``char*`` pointers stored at ``out_strs``. + * If smaller than the max size, only this many strings are copied * \param[out] out_len Total number of evaluation datasets + * \param buffer_len Size of pre-allocated strings. + * Content is copied up to ``buffer_len - 1`` and null-terminated + * \param[out] out_buffer_len String sizes required to do the full string copies * \param[out] out_strs Names of evaluation datasets, should pre-allocate memory * \return 0 when succeed, -1 when failure happens */ LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalNames(BoosterHandle handle, + const int len, int* out_len, + const size_t buffer_len, + size_t* out_buffer_len, char** out_strs); /*! * \brief Get names of features. * \param handle Handle of booster + * \param len Number of ``char*`` pointers stored at ``out_strs``. + * If smaller than the max size, only this many strings are copied * \param[out] out_len Total number of features + * \param buffer_len Size of pre-allocated strings. + * Content is copied up to ``buffer_len - 1`` and null-terminated + * \param[out] out_buffer_len String sizes required to do the full string copies * \param[out] out_strs Names of features, should pre-allocate memory * \return 0 when succeed, -1 when failure happens */ LIGHTGBM_C_EXPORT int LGBM_BoosterGetFeatureNames(BoosterHandle handle, + const int len, int* out_len, + const size_t buffer_len, + size_t* out_buffer_len, char** out_strs); /*! 
diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 6fcb37ff3c89..cd3288126474 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -2711,14 +2711,24 @@ def feature_name(self): num_feature = self.num_feature() # Get name of features tmp_out_len = ctypes.c_int(0) - string_buffers = [ctypes.create_string_buffer(255) for i in range_(num_feature)] + reserved_string_buffer_size = 255 + required_string_buffer_size = ctypes.c_size_t(0) + string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for i in range_(num_feature)] ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers)) _safe_call(_LIB.LGBM_BoosterGetFeatureNames( self.handle, + num_feature, ctypes.byref(tmp_out_len), + reserved_string_buffer_size, + ctypes.byref(required_string_buffer_size), ptr_string_buffers)) if num_feature != tmp_out_len.value: raise ValueError("Length of feature names doesn't equal with num_feature") + if reserved_string_buffer_size < required_string_buffer_size.value: + raise BufferError( + "Allocated feature name buffer size ({}) was inferior to the needed size ({})." 
+ .format(reserved_string_buffer_size, required_string_buffer_size.value) + ) return [string_buffers[i].value.decode() for i in range_(num_feature)] def feature_importance(self, importance_type='split', iteration=None): @@ -2898,14 +2908,26 @@ def __get_eval_info(self): if self.__num_inner_eval > 0: # Get name of evals tmp_out_len = ctypes.c_int(0) - string_buffers = [ctypes.create_string_buffer(255) for i in range_(self.__num_inner_eval)] + reserved_string_buffer_size = 255 + required_string_buffer_size = ctypes.c_size_t(0) + string_buffers = [ + ctypes.create_string_buffer(reserved_string_buffer_size) for i in range_(self.__num_inner_eval) + ] ptr_string_buffers = (ctypes.c_char_p * self.__num_inner_eval)(*map(ctypes.addressof, string_buffers)) _safe_call(_LIB.LGBM_BoosterGetEvalNames( self.handle, + self.__num_inner_eval, ctypes.byref(tmp_out_len), + reserved_string_buffer_size, + ctypes.byref(required_string_buffer_size), ptr_string_buffers)) if self.__num_inner_eval != tmp_out_len.value: raise ValueError("Length of eval names doesn't equal with num_evals") + if reserved_string_buffer_size < required_string_buffer_size.value: + raise BufferError( + "Allocated eval name buffer size ({}) was inferior to the needed size ({})." 
+ .format(reserved_string_buffer_size, required_string_buffer_size.value) + ) self.__name_inner_eval = \ [string_buffers[i].value.decode() for i in range_(self.__num_inner_eval)] self.__higher_better_inner_eval = \ diff --git a/src/c_api.cpp b/src/c_api.cpp index 224c74d6c730..22d53937e337 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -501,21 +501,31 @@ class Booster { return ret; } - int GetEvalNames(char** out_strs) const { + int GetEvalNames(char** out_strs, const int len, const size_t buffer_len, size_t *out_buffer_len) const { + *out_buffer_len = 0; int idx = 0; for (const auto& metric : train_metric_) { for (const auto& name : metric->GetName()) { - std::memcpy(out_strs[idx], name.c_str(), name.size() + 1); + if (idx < len && buffer_len > 0) { // buffer_len == 0 would make buffer_len - 1 wrap around + std::memcpy(out_strs[idx], name.c_str(), std::min(name.size() + 1, buffer_len)); + out_strs[idx][buffer_len - 1] = '\0'; + } + *out_buffer_len = std::max(name.size() + 1, *out_buffer_len); ++idx; } } return idx; } - int GetFeatureNames(char** out_strs) const { + int GetFeatureNames(char** out_strs, const int len, const size_t buffer_len, size_t *out_buffer_len) const { + *out_buffer_len = 0; int idx = 0; for (const auto& name : boosting_->FeatureNames()) { - std::memcpy(out_strs[idx], name.c_str(), name.size() + 1); + if (idx < len && buffer_len > 0) { // buffer_len == 0 would make buffer_len - 1 wrap around + std::memcpy(out_strs[idx], name.c_str(), std::min(name.size() + 1, buffer_len)); + out_strs[idx][buffer_len - 1] = '\0'; + } + *out_buffer_len = std::max(name.size() + 1, *out_buffer_len); ++idx; } return idx; @@ -1356,17 +1366,27 @@ int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) { API_END(); } -int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) { +int LGBM_BoosterGetEvalNames(BoosterHandle handle, + const int len, + int* out_len, + const size_t buffer_len, + size_t* out_buffer_len, + char** out_strs) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); - *out_len = ref_booster->GetEvalNames(out_strs); + *out_len = 
ref_booster->GetEvalNames(out_strs, len, buffer_len, out_buffer_len); API_END(); } -int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) { +int LGBM_BoosterGetFeatureNames(BoosterHandle handle, + const int len, + int* out_len, + const size_t buffer_len, + size_t* out_buffer_len, + char** out_strs) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); - *out_len = ref_booster->GetFeatureNames(out_strs); + *out_len = ref_booster->GetFeatureNames(out_strs, len, buffer_len, out_buffer_len); API_END(); } diff --git a/src/lightgbm_R.cpp b/src/lightgbm_R.cpp index 4ba70a5f1c22..9cb7a62d6947 100644 --- a/src/lightgbm_R.cpp +++ b/src/lightgbm_R.cpp @@ -449,15 +449,25 @@ LGBM_SE LGBM_BoosterGetEvalNames_R(LGBM_SE handle, R_API_BEGIN(); int len; CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len)); + + const size_t reserved_string_size = 128; std::vector> names(len); std::vector ptr_names(len); for (int i = 0; i < len; ++i) { - names[i].resize(128); + names[i].resize(reserved_string_size); ptr_names[i] = names[i].data(); } + int out_len; - CHECK_CALL(LGBM_BoosterGetEvalNames(R_GET_PTR(handle), &out_len, ptr_names.data())); + size_t required_string_size; + CHECK_CALL( + LGBM_BoosterGetEvalNames( + R_GET_PTR(handle), + len, &out_len, + reserved_string_size, &required_string_size, + ptr_names.data())); CHECK_EQ(out_len, len); + CHECK_GE(reserved_string_size, required_string_size); auto merge_names = Common::Join(ptr_names, "\t"); EncodeChar(eval_names, merge_names.c_str(), buf_len, actual_len, merge_names.size() + 1); R_API_END(); diff --git a/swig/StringArray.hpp b/swig/StringArray.hpp new file mode 100644 index 000000000000..4fef62a125e2 --- /dev/null +++ b/swig/StringArray.hpp @@ -0,0 +1,150 @@ +/*! + * Copyright (c) 2020 Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. 
+ */ +#ifndef __STRING_ARRAY_H__ +#define __STRING_ARRAY_H__ + +#include +#include +#include + +/** + * Container that manages an array of fixed-length strings. + * + * To be compatible with SWIG's `various.i` extension module, + * the array of pointers to char* must be NULL-terminated: + * [char*, char*, char*, ..., NULL] + * This implies that the length of this array is bigger + * by 1 element than the number of char* it stores. + * I.e., _num_elements == _array.size()-1 + * + * The class also takes care of allocation of the underlying + * char* memory. + */ +class StringArray +{ + public: + StringArray(size_t num_elements, size_t string_size) + : _string_size(string_size), + _array(num_elements + 1, nullptr) + { + _allocate_strings(num_elements, string_size); + } + + ~StringArray() + { + _release_strings(); + } + + /** + * Returns the pointer to the raw array. + * Notice its size is greater than the number of stored strings by 1. + * + * @return char** pointer to raw data (null-terminated). + */ + char **data() noexcept + { + return _array.data(); + } + + /** + * Return char* from the array of size _string_size+1. + * Notice the last element in _array is already + * considered out of bounds. + * + * @param index Index of the element to retrieve. + * @return pointer or nullptr if index is out of bounds. + */ + char *getitem(size_t index) noexcept + { + if (_in_bounds(index)) + return _array[index]; + else + return nullptr; + } + + /** + * Safely copies the full content data + * into one of the strings in the array. + * If that is not possible, returns error (-1). + * + * @param index index of the string in the array. + * @param content content to store + * + * @return In case index results in out of bounds access, + * or content + 1 (null-terminator byte) doesn't fit + * into the target string (_string_size), it errors out + * and returns -1. 
+ */ + int setitem(size_t index, std::string content) noexcept + { + if (_in_bounds(index) && content.size() < _string_size) + { + std::strcpy(_array[index], content.c_str()); + return 0; + } else { + return -1; + } + } + + /** + * @return number of stored strings. + */ + size_t get_num_elements() noexcept + { + return _array.size() - 1; + } + + private: + + /** + * Returns true if and only if within bounds. + * Notice that it excludes the last element of _array (NULL). + * + * @param index index of the element + * @return bool true if within bounds + */ + bool _in_bounds(size_t index) noexcept + { + return index < get_num_elements(); + } + + /** + * Allocate an array of fixed-length strings. + * + * Since a NULL-terminated array is required by SWIG's `various.i`, + * the size of the array is actually `num_elements + 1` but only + * num_elements are filled. + * + * @param num_elements Number of strings to store in the array. + * @param string_size The size of each string in the array (one extra byte is allocated for the terminator). + */ + void _allocate_strings(size_t num_elements, size_t string_size) + { + for (size_t i = 0; i < num_elements; ++i) + { + // Leave space for \0 terminator: + _array[i] = new (std::nothrow) char[string_size + 1]; + + // Check memory allocation: + if (! _array[i]) { + _release_strings(); + throw std::bad_alloc(); + } + } + } + + /** + * Deletes the allocated strings. + */ + void _release_strings() noexcept + { + std::for_each(_array.begin(), _array.end(), [](char* c) { delete[] c; }); + } + + const size_t _string_size; + std::vector _array; +}; + +#endif // __STRING_ARRAY_H__ diff --git a/swig/StringArray.i b/swig/StringArray.i new file mode 100644 index 000000000000..12f1a5d2b21c --- /dev/null +++ b/swig/StringArray.i @@ -0,0 +1,95 @@ +/*! + * Copyright (c) 2020 Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. 
+ */ +/** + * This wraps the StringArray.hpp class for SWIG usage, + * adding the basic C-style wrappers needed to make it + * usable for the users of the low-level lightgbmJNI API. + */ + +%{ +#include "../swig/StringArray.hpp" +%} + +// Use SWIG's `various.i` to get a String[] directly in one call: +%apply char **STRING_ARRAY {char **StringArrayHandle_get_strings}; + +%inline %{ + + typedef void* StringArrayHandle; + + /** + * @brief Creates a new StringArray and returns its handle. + * + * @param num_strings number of strings to store. + * @param string_size the maximum number of characters that can be stored in each string. + * @return StringArrayHandle or nullptr in case of allocation failure. + */ + StringArrayHandle StringArrayHandle_create(size_t num_strings, size_t string_size) { + try { + return new StringArray(num_strings, string_size); + } catch (std::bad_alloc &e) { + return nullptr; + } + } + + /** + * @brief Free the StringArray object. + * + * @param handle StringArray handle. + */ + void StringArrayHandle_free(StringArrayHandle handle) + { + delete reinterpret_cast(handle); + } + + /** + * @brief Return the raw pointer to the array of strings. + * Wrapped in Java into String[] automatically. + * + * @param handle StringArray handle. + * @return Raw pointer to the string array which `various.i` maps to String[]. + */ + char **StringArrayHandle_get_strings(StringArrayHandle handle) + { + return reinterpret_cast(handle)->data(); + } + + /** + * For the end user to extract a specific string from the StringArray object. + * + * @param handle StringArray handle. + * @param index index of the string to retrieve from the array. + * @return raw pointer to string at index, or nullptr if out of bounds. + */ + char *StringArrayHandle_get_string(StringArrayHandle handle, int index) + { + return reinterpret_cast(handle)->getitem(index); + } + + /** + * @brief Replaces one string of the array at index with the new content. 
+ * + * @param handle StringArray handle. + * @param index Index of the string to replace + * @param new_content The content to replace + * @return 0 (success) or -1 (error) in case of out of bounds index or too large content. + */ + int StringArrayHandle_set_string(StringArrayHandle handle, size_t index, const char* new_content) + { + return reinterpret_cast(handle)->setitem(index, std::string(new_content)); + } + + /** + * @brief Retrieve the number of strings in the StringArray. + * + * @param handle StringArray handle. + * @return number of strings that the array stores. + */ + size_t StringArrayHandle_get_num_elements(StringArrayHandle handle) + { + return reinterpret_cast(handle)->get_num_elements(); + } + +%} diff --git a/swig/StringArray_API_extensions.i b/swig/StringArray_API_extensions.i new file mode 100644 index 000000000000..b44f8f263793 --- /dev/null +++ b/swig/StringArray_API_extensions.i @@ -0,0 +1,107 @@ +/*! + * Copyright (c) 2020 Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. + */ +/** + * This SWIG interface extension provides support to + * allocate, return and manage arrays of strings through + * the class StringArray. + * + * This is then used to generate wrappers that return newly-allocated + * arrays of strings, so the user can them access them easily as a String[] + * on the Java side by a single call to StringArray::data(), and even manipulate + * them. + * + * It also implements working wrappers to: + * - LGBM_BoosterGetEvalNames (re-implemented with new API) + * - LGBM_BoosterGetFeatureNames (original non-wrapped version didn't work). + * where the wrappers names end with "SWIG". 
+ */ + + +#include + +%include "./StringArray.i" + +%inline %{ + + #define API_OK_OR_VALUE(api_return, return_value) if (api_return == -1) return return_value + #define API_OK_OR_NULL(api_return) API_OK_OR_VALUE(api_return, nullptr) + + /** + * @brief Wraps LGBM_BoosterGetEvalNames. + * + * In case of success a new StringArray is created and returned, + * which you're responsible for freeing, + * @see StringArrayHandle_free(). + * In case of failure such resource is freed and nullptr is returned. + * Check for that case with null (lightgbmlib) or 0 (lightgbmlibJNI). + * + * @param handle Booster handle + * @return StringArrayHandle with the eval names (or nullptr in case of error) + */ + StringArrayHandle LGBM_BoosterGetEvalNamesSWIG(BoosterHandle handle) + { + int eval_counts; + size_t string_size; + std::unique_ptr strings(nullptr); + + // Retrieve required allocation space: + API_OK_OR_NULL(LGBM_BoosterGetEvalNames(handle, + 0, &eval_counts, + 0, &string_size, + strings->data())); + + try { + strings.reset(new StringArray(eval_counts, string_size)); + } catch (std::bad_alloc &e) { + LGBM_SetLastError("Failure to allocate memory."); + return nullptr; + } + + API_OK_OR_NULL(LGBM_BoosterGetEvalNames(handle, + eval_counts, &eval_counts, + string_size, &string_size, + strings->data())); + + return strings.release(); + } + + /** + * @brief Wraps LGBM_BoosterGetFeatureNames. + * + * Allocates a new StringArray. You must free it yourself if it succeeds. + * @see StringArrayHandle_free(). + * In case of failure such resource is freed and nullptr is returned. + * Check for that case with null (lightgbmlib) or 0 (lightgbmlibJNI). 
+ * + * @param handle Booster handle + * @return StringArrayHandle with the feature names (or nullptr in case of error) + */ + StringArrayHandle LGBM_BoosterGetFeatureNamesSWIG(BoosterHandle handle) + { + int num_features; + size_t max_feature_name_size; + std::unique_ptr strings(nullptr); + + // Retrieve required allocation space: + API_OK_OR_NULL(LGBM_BoosterGetFeatureNames(handle, + 0, &num_features, + 0, &max_feature_name_size, + nullptr)); + + try { + strings.reset(new StringArray(num_features, max_feature_name_size)); + } catch (std::bad_alloc &e) { + LGBM_SetLastError("Failure to allocate memory."); + return nullptr; + } + + API_OK_OR_NULL(LGBM_BoosterGetFeatureNames(handle, + num_features, &num_features, + max_feature_name_size, &max_feature_name_size, + strings->data())); + + return strings.release(); + } +%} diff --git a/swig/lightgbmlib.i b/swig/lightgbmlib.i index 34468dd805d1..5fba17e12adf 100644 --- a/swig/lightgbmlib.i +++ b/swig/lightgbmlib.i @@ -6,6 +6,7 @@ %module lightgbmlib %ignore LGBM_BoosterSaveModelToString; %ignore LGBM_BoosterGetEvalNames; +%ignore LGBM_BoosterGetFeatureNames; %{ /* Includes the header in the wrapper code */ #include "../include/LightGBM/export.h" @@ -73,19 +74,6 @@ return dst; } - char ** LGBM_BoosterGetEvalNamesSWIG(BoosterHandle handle, - int eval_counts) { - char** dst = new char*[eval_counts]; - for (int i = 0; i < eval_counts; ++i) { - dst[i] = new char[128]; - } - int result = LGBM_BoosterGetEvalNames(handle, &eval_counts, dst); - if (result != 0) { - return nullptr; - } - return dst; - } - int LGBM_BoosterPredictForMatSingle(JNIEnv *jenv, jdoubleArray data, BoosterHandle handle, @@ -249,10 +237,6 @@ %array_functions(float, floatArray) %array_functions(int, intArray) %array_functions(long, longArray) -/* Note: there is a bug in the SWIG generated string arrays when creating - a new array with strings where the strings are prematurely deallocated -*/ -%array_functions(char *, stringArray) /* Custom pointer 
manipulation template */ %define %pointer_manipulation(TYPE, NAME) @@ -301,3 +285,5 @@ TYPE *NAME##_handle(); /* Allow retrieving handle to void** */ %pointer_handle(void*, voidpp) + +%include "StringArray_API_extensions.i"