Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SWIG methods that return char** #2850

Merged
merged 7 commits into from
Mar 20, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -555,23 +555,39 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalCounts(BoosterHandle handle,
/*!
* \brief Get names of evaluation datasets.
* \param handle Handle of booster
* \param len Number of char* pointers stored at char** out_strs.
* If smaller than the max size, only this many strings are copied.
AlbertoEAF marked this conversation as resolved.
Show resolved Hide resolved
* \param[out] out_len Total number of evaluation datasets
* \param buffer_len Size of pre-allocated strings.
* Content is copied up to buffer_len-1 and null-terminated.
AlbertoEAF marked this conversation as resolved.
Show resolved Hide resolved
* \param[out] out_buffer_len String sizes required to do the full string copies
* \param[out] out_strs Names of evaluation datasets, should pre-allocate memory
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalNames(BoosterHandle handle,
const int len,
int* out_len,
const size_t buffer_len,
size_t* out_buffer_len,
char** out_strs);

/*!
* \brief Get names of features.
* \param handle Handle of booster
* \param len Number of char* pointers stored at char** out_strs.
* If smaller than the max size, only this many strings are copied.
AlbertoEAF marked this conversation as resolved.
Show resolved Hide resolved
* \param[out] out_len Total number of features
* \param buffer_len Size of pre-allocated strings.
* Content is copied up to buffer_len-1 and null-terminated.
AlbertoEAF marked this conversation as resolved.
Show resolved Hide resolved
* \param[out] out_buffer_len String sizes required to do the full string copies
* \param[out] out_strs Names of features, should pre-allocate memory
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterGetFeatureNames(BoosterHandle handle,
const int len,
int* out_len,
const size_t buffer_len,
size_t* out_buffer_len,
char** out_strs);

/*!
Expand Down
26 changes: 24 additions & 2 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2711,14 +2711,24 @@ def feature_name(self):
num_feature = self.num_feature()
# Get name of features
tmp_out_len = ctypes.c_int(0)
string_buffers = [ctypes.create_string_buffer(255) for i in range_(num_feature)]
reserved_string_buffer_size = 255
required_string_buffer_size = ctypes.c_size_t(0)
string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for i in range_(num_feature)]
ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers))
_safe_call(_LIB.LGBM_BoosterGetFeatureNames(
self.handle,
num_feature,
ctypes.byref(tmp_out_len),
reserved_string_buffer_size,
ctypes.byref(required_string_buffer_size),
ptr_string_buffers))
if num_feature != tmp_out_len.value:
raise ValueError("Length of feature names doesn't equal with num_feature")
if reserved_string_buffer_size < required_string_buffer_size.value:
raise BufferError(
"Allocated feature name buffer size ({}) was inferior to the needed size ({})."
.format(reserved_string_buffer_size, required_string_buffer_size.value)
)
return [string_buffers[i].value.decode() for i in range_(num_feature)]

def feature_importance(self, importance_type='split', iteration=None):
Expand Down Expand Up @@ -2898,14 +2908,26 @@ def __get_eval_info(self):
if self.__num_inner_eval > 0:
# Get name of evals
tmp_out_len = ctypes.c_int(0)
string_buffers = [ctypes.create_string_buffer(255) for i in range_(self.__num_inner_eval)]
reserved_string_buffer_size = 255
required_string_buffer_size = ctypes.c_size_t(0)
string_buffers = [
ctypes.create_string_buffer(reserved_string_buffer_size) for i in range_(self.__num_inner_eval)
]
ptr_string_buffers = (ctypes.c_char_p * self.__num_inner_eval)(*map(ctypes.addressof, string_buffers))
_safe_call(_LIB.LGBM_BoosterGetEvalNames(
self.handle,
self.__num_inner_eval,
ctypes.byref(tmp_out_len),
reserved_string_buffer_size,
ctypes.byref(required_string_buffer_size),
ptr_string_buffers))
if self.__num_inner_eval != tmp_out_len.value:
raise ValueError("Length of eval names doesn't equal with num_evals")
if reserved_string_buffer_size < required_string_buffer_size.value:
raise BufferError(
"Allocated eval buffer size ({}) was inferior to the needed size ({})."
AlbertoEAF marked this conversation as resolved.
Show resolved Hide resolved
.format(reserved_string_buffer_size, required_string_buffer_size.value)
)
self.__name_inner_eval = \
[string_buffers[i].value.decode() for i in range_(self.__num_inner_eval)]
self.__higher_better_inner_eval = \
Expand Down
40 changes: 32 additions & 8 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,21 +501,31 @@ class Booster {
return ret;
}

int GetEvalNames(char** out_strs) const {
int GetEvalNames(char** out_strs, const int len, const size_t buffer_len, size_t *out_buffer_len) const {
*out_buffer_len = 0;
int idx = 0;
for (const auto& metric : train_metric_) {
for (const auto& name : metric->GetName()) {
std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
if (idx < len) {
std::memcpy(out_strs[idx], name.c_str(), std::min(name.size() + 1, buffer_len));
out_strs[idx][buffer_len-1] = '\0';
AlbertoEAF marked this conversation as resolved.
Show resolved Hide resolved
}
*out_buffer_len = std::max(name.size() + 1, *out_buffer_len);
++idx;
}
}
return idx;
}

int GetFeatureNames(char** out_strs) const {
int GetFeatureNames(char** out_strs, const int len, const size_t buffer_len, size_t *out_buffer_len) const {
*out_buffer_len = 0;
int idx = 0;
for (const auto& name : boosting_->FeatureNames()) {
std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
if (idx < len) {
std::memcpy(out_strs[idx], name.c_str(), std::min(name.size() + 1, buffer_len));
out_strs[idx][buffer_len-1] = '\0';
AlbertoEAF marked this conversation as resolved.
Show resolved Hide resolved
}
*out_buffer_len = std::max(name.size() + 1, *out_buffer_len);
++idx;
}
return idx;
Expand Down Expand Up @@ -1356,17 +1366,31 @@ int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
API_END();
}

int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
int LGBM_BoosterGetEvalNames(
BoosterHandle handle,
const int len,
int* out_len,
const size_t buffer_len,
size_t* out_buffer_len,
char** out_strs)
{
AlbertoEAF marked this conversation as resolved.
Show resolved Hide resolved
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_len = ref_booster->GetEvalNames(out_strs);
*out_len = ref_booster->GetEvalNames(out_strs, len, buffer_len, out_buffer_len);
API_END();
}

int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
int LGBM_BoosterGetFeatureNames(
BoosterHandle handle,
const int len,
int* out_len,
const size_t buffer_len,
size_t* out_buffer_len,
char** out_strs)
{
AlbertoEAF marked this conversation as resolved.
Show resolved Hide resolved
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_len = ref_booster->GetFeatureNames(out_strs);
*out_len = ref_booster->GetFeatureNames(out_strs, len, buffer_len, out_buffer_len);
API_END();
}

Expand Down
16 changes: 14 additions & 2 deletions src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,15 +449,27 @@ LGBM_SE LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
R_API_BEGIN();
int len;
CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len));

const size_t reserved_string_size = 128;
std::vector<std::vector<char>> names(len);
std::vector<char*> ptr_names(len);
for (int i = 0; i < len; ++i) {
names[i].resize(128);
names[i].resize(reserved_string_size);
ptr_names[i] = names[i].data();
}

int out_len;
CHECK_CALL(LGBM_BoosterGetEvalNames(R_GET_PTR(handle), &out_len, ptr_names.data()));
size_t required_string_size;
CHECK_CALL(
LGBM_BoosterGetEvalNames(
R_GET_PTR(handle),
len, &out_len,
reserved_string_size, &required_string_size,
ptr_names.data()
)
);
AlbertoEAF marked this conversation as resolved.
Show resolved Hide resolved
CHECK_EQ(out_len, len);
CHECK_GE(reserved_string_size, required_string_size);
auto merge_names = Common::Join<char*>(ptr_names, "\t");
EncodeChar(eval_names, merge_names.c_str(), buf_len, actual_len, merge_names.size() + 1);
R_API_END();
Expand Down
146 changes: 146 additions & 0 deletions swig/StringArray.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
#ifndef __STRING_ARRAY_H__
AlbertoEAF marked this conversation as resolved.
Show resolved Hide resolved
#define __STRING_ARRAY_H__

#include <new>
#include <vector>
#include <algorithm>

/**
* Container that manages an array of fixed-length strings.
*
* To be compatible with SWIG's `various.i` extension module,
* the array of pointers to char* must be NULL-terminated:
* [char*, char*, char*, ..., NULL]
* This implies that the length of this array is bigger
* by 1 element than the number of char* it stores.
* I.e., _num_elements == _array.size()-1
*
* The class also takes care of allocation of the underlying
* char* memory.
*/
class StringArray
{
public:
StringArray(size_t num_elements, size_t string_size)
: _string_size(string_size),
_array(num_elements + 1, nullptr)
{
_allocate_strings(num_elements, string_size);
}

~StringArray()
{
_release_strings();
}

/**
* Returns the pointer to the raw array.
* Notice its size is greater than the number of stored strings by 1.
*
* @return char** pointer to raw data (null-terminated).
*/
char **data() noexcept
{
return _array.data();
}

/**
* Return char* from the array of size _string_size+1.
* Notice the last element in _array is already
* considered out of bounds.
*
* @param index Index of the element to retrieve.
* @return pointer or nullptr if index is out of bounds.
*/
char *getitem(size_t index) noexcept
{
if (_in_bounds(index))
return _array[index];
else
return nullptr;
}

/**
* Safely copies the full content data
* into one of the strings in the array.
* If that is not possible, returns error (-1).
*
* @param index index of the string in the array.
* @param content content to store
*
* @return In case index results in out of bounds access,
* or content + 1 (null-terminator byte) doesn't fit
* into the target string (_string_size), it errors out
* and returns -1.
*/
int setitem(size_t index, std::string content) noexcept
{
if (_in_bounds(index) && content.size() < _string_size)
{
std::strcpy(_array[index], content.c_str());
return 0;
} else {
return -1;
}
}

/**
* @return number of stored strings.
*/
size_t get_num_elements() noexcept
{
return _array.size() - 1;
}

private:

/**
* Returns true if and only if within bounds.
* Notice that it excludes the last element of _array (NULL).
*
* @param index index of the element
* @return bool true if within bounds
*/
bool _in_bounds(size_t index) noexcept
{
return index < get_num_elements();
}

/**
* Allocate an array of fixed-length strings.
*
* Since a NULL-terminated array is required by SWIG's `various.i`,
* the size of the array is actually `num_elements + 1` but only
* num_elements are filled.
*
* @param num_elements Number of strings to store in the array.
* @param string_size The size of each string in the array.
*/
void _allocate_strings(int num_elements, int string_size)
{
for (int i = 0; i < num_elements; ++i)
{
// Leave space for \0 terminator:
_array[i] = new (std::nothrow) char[string_size + 1];

// Check memory allocation:
if (! _array[i]) {
_release_strings();
throw std::bad_alloc();
}
}
}

/**
* Deletes the allocated strings.
*/
void _release_strings() noexcept
{
std::for_each(_array.begin(), _array.end(), [](char* c) { delete[] c; });
}

const size_t _string_size;
std::vector<char*> _array;
};

#endif // __STRING_ARRAY_H__
Loading