Skip to content

Commit

Permalink
ARROW-13780: [Gandiva][UDF] Fix bug in udf space/rpad/lpad
Browse files Browse the repository at this point in the history
- add max/min return length for space/lpad/rpad udfs
- correct return length

Closes apache#11016 from ZMZ91/bugfix/limit_return_chars_count

Authored-by: ZMZ <[email protected]>
Signed-off-by: Pindikura Ravindra <[email protected]>
Signed-off-by: Yuan Zhou <[email protected]>
  • Loading branch information
ZMZ91 authored and zhouyuan committed Jun 13, 2022
1 parent e9410e5 commit e43021d
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 19 deletions.
95 changes: 76 additions & 19 deletions cpp/src/gandiva/precompiled/string_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,37 @@ const char* lower_utf8(gdv_int64 context, const char* data, gdv_int32 data_len,
return ret;
}

// set max/min str length for space_int32, space_int64, lpad_utf8_int32_utf8
// and rpad_utf8_int32_utf8 to avoid exceptions
static const gdv_int32 max_str_length = 65536;
static const gdv_int32 min_str_length = 0;
// Returns a string of 'n' spaces.
#define SPACE_STR(IN_TYPE) \
GANDIVA_EXPORT \
const char* space_##IN_TYPE(gdv_int64 ctx, gdv_##IN_TYPE n, int32_t* out_len) { \
n = std::min(static_cast<gdv_##IN_TYPE>(max_str_length), n); \
n = std::max(static_cast<gdv_##IN_TYPE>(min_str_length), n); \
gdv_int32 n_times = static_cast<gdv_int32>(n); \
if (n_times <= 0) { \
*out_len = 0; \
return ""; \
} \
char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(ctx, n_times)); \
if (ret == nullptr) { \
gdv_fn_context_set_error_msg(ctx, "Could not allocate memory for output string"); \
*out_len = 0; \
return ""; \
} \
for (int i = 0; i < n_times; i++) { \
ret[i] = ' '; \
} \
*out_len = n_times; \
return ret; \
}

SPACE_STR(int32)
SPACE_STR(int64)

// Reverse a utf8 sequence
FORCE_INLINE
const char* reverse_utf8(gdv_int64 context, const char* data, gdv_int32 data_len,
Expand Down Expand Up @@ -1445,36 +1476,58 @@ const char* replace_utf8_utf8_utf8(gdv_int64 context, const char* text,
out_len);
}

FORCE_INLINE
gdv_int32 evaluate_return_char_length(gdv_int32 text_len, gdv_int32 actual_text_len,
gdv_int32 return_length, const char* fill_text,
gdv_int32 fill_text_len) {
gdv_int32 fill_actual_text_len = utf8_length_ignore_invalid(fill_text, fill_text_len);
gdv_int32 repeat_times = (return_length - actual_text_len) / fill_actual_text_len;
gdv_int32 return_char_length = repeat_times * fill_text_len + text_len;
gdv_int32 mod = (return_length - actual_text_len) % fill_actual_text_len;
gdv_int32 char_len = 0;
gdv_int32 fill_index = 0;
for (gdv_int32 i = 0; i < mod; i++) {
char_len = utf8_char_length(fill_text[fill_index]);
fill_index += char_len;
return_char_length += char_len;
}
return return_char_length;
}

FORCE_INLINE
const char* lpad_utf8_int32_utf8(gdv_int64 context, const char* text, gdv_int32 text_len,
gdv_int32 return_length, const char* fill_text,
gdv_int32 fill_text_len, gdv_int32* out_len) {
// if the text length or the defined return length (number of characters to return)
// is <=0, then return an empty string.
return_length = std::min(max_str_length, return_length);
return_length = std::max(min_str_length, return_length);
if (text_len == 0 || return_length <= 0) {
*out_len = 0;
return "";
}

// count the number of utf8 characters on text, ignoring invalid bytes
int text_char_count = utf8_length_ignore_invalid(text, text_len);
int actual_text_len = utf8_length_ignore_invalid(text, text_len);

if (return_length == text_char_count ||
(return_length > text_char_count && fill_text_len == 0)) {
if (return_length == actual_text_len ||
(return_length > actual_text_len && fill_text_len == 0)) {
// case where the return length is same as the text's length, or if it need to
// fill into text but "fill_text" is empty, then return text directly.
*out_len = text_len;
return text;
} else if (return_length < text_char_count) {
} else if (return_length < actual_text_len) {
// case where it truncates the result on return length.
*out_len = utf8_byte_pos(context, text, text_len, return_length);
return text;
} else {
// case (return_length > text_char_count)
// case (return_length > actual_text_len)
// case where it needs to copy "fill_text" on the string left. The total number
// of chars to copy is given by (return_length - text_char_count)
char* ret =
reinterpret_cast<gdv_binary>(gdv_fn_context_arena_malloc(context, return_length));
// of chars to copy is given by (return_length - actual_text_len)
gdv_int32 return_char_length = evaluate_return_char_length(
text_len, actual_text_len, return_length, fill_text, fill_text_len);
char* ret = reinterpret_cast<gdv_binary>(
gdv_fn_context_arena_malloc(context, return_char_length));
if (ret == nullptr) {
gdv_fn_context_set_error_msg(context,
"Could not allocate memory for output string");
Expand All @@ -1484,12 +1537,12 @@ const char* lpad_utf8_int32_utf8(gdv_int64 context, const char* text, gdv_int32
// try to fulfill the return string with the "fill_text" continuously
int32_t copied_chars_count = 0;
int32_t copied_chars_position = 0;
while (copied_chars_count < return_length - text_char_count) {
while (copied_chars_count < return_length - actual_text_len) {
int32_t char_len;
int32_t fill_index;
// for each char, evaluate its length to consider it when mem copying
for (fill_index = 0; fill_index < fill_text_len; fill_index += char_len) {
if (copied_chars_count >= return_length - text_char_count) {
if (copied_chars_count >= return_length - actual_text_len) {
break;
}
char_len = utf8_char_length(fill_text[fill_index]);
Expand All @@ -1513,29 +1566,33 @@ const char* rpad_utf8_int32_utf8(gdv_int64 context, const char* text, gdv_int32
gdv_int32 fill_text_len, gdv_int32* out_len) {
// if the text length or the defined return length (number of characters to return)
// is <=0, then return an empty string.
return_length = std::min(max_str_length, return_length);
return_length = std::max(min_str_length, return_length);
if (text_len == 0 || return_length <= 0) {
*out_len = 0;
return "";
}

// count the number of utf8 characters on text, ignoring invalid bytes
int text_char_count = utf8_length_ignore_invalid(text, text_len);
int actual_text_len = utf8_length_ignore_invalid(text, text_len);

if (return_length == text_char_count ||
(return_length > text_char_count && fill_text_len == 0)) {
if (return_length == actual_text_len ||
(return_length > actual_text_len && fill_text_len == 0)) {
// case where the return length is same as the text's length, or if it need to
// fill into text but "fill_text" is empty, then return text directly.
*out_len = text_len;
return text;
} else if (return_length < text_char_count) {
} else if (return_length < actual_text_len) {
// case where it truncates the result on return length.
*out_len = utf8_byte_pos(context, text, text_len, return_length);
return text;
} else {
// case (return_length > text_char_count)
// case (return_length > actual_text_len)
// case where it needs to copy "fill_text" on the string right
char* ret =
reinterpret_cast<gdv_binary>(gdv_fn_context_arena_malloc(context, return_length));
gdv_int32 return_char_length = evaluate_return_char_length(
text_len, actual_text_len, return_length, fill_text, fill_text_len);
char* ret = reinterpret_cast<gdv_binary>(
gdv_fn_context_arena_malloc(context, return_char_length));
if (ret == nullptr) {
gdv_fn_context_set_error_msg(context,
"Could not allocate memory for output string");
Expand All @@ -1547,12 +1604,12 @@ const char* rpad_utf8_int32_utf8(gdv_int64 context, const char* text, gdv_int32
// try to fulfill the return string with the "fill_text" continuously
int32_t copied_chars_count = 0;
int32_t copied_chars_position = 0;
while (text_char_count + copied_chars_count < return_length) {
while (actual_text_len + copied_chars_count < return_length) {
int32_t char_len;
int32_t fill_length;
// for each char, evaluate its length to consider it when mem copying
for (fill_length = 0; fill_length < fill_text_len; fill_length += char_len) {
if (text_char_count + copied_chars_count >= return_length) {
if (actual_text_len + copied_chars_count >= return_length) {
break;
}
char_len = utf8_char_length(fill_text[fill_length]);
Expand Down
53 changes: 53 additions & 0 deletions cpp/src/gandiva/precompiled/string_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,41 @@ TEST(TestStringOps, TestBeginsEnds) {
EXPECT_FALSE(ends_with_utf8_utf8("hello", 5, "sir", 3));
}

TEST(TestStringOps, TestSpace) {
// Space - returns a string with 'n' spaces
gandiva::ExecutionContext ctx;
uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
int32_t out_len = 0;

auto out = space_int32(ctx_ptr, 1, &out_len);
EXPECT_EQ(std::string(out, out_len), " ");
out = space_int32(ctx_ptr, 10, &out_len);
EXPECT_EQ(std::string(out, out_len), " ");
out = space_int32(ctx_ptr, 5, &out_len);
EXPECT_EQ(std::string(out, out_len), " ");
out = space_int32(ctx_ptr, -5, &out_len);
EXPECT_EQ(std::string(out, out_len), "");
out = space_int32(ctx_ptr, 65537, &out_len);
EXPECT_EQ(std::string(out, out_len), std::string(65536, ' '));
out = space_int32(ctx_ptr, 2147483647, &out_len);
EXPECT_EQ(std::string(out, out_len), std::string(65536, ' '));

out = space_int64(ctx_ptr, 2, &out_len);
EXPECT_EQ(std::string(out, out_len), " ");
out = space_int64(ctx_ptr, 9, &out_len);
EXPECT_EQ(std::string(out, out_len), " ");
out = space_int64(ctx_ptr, 4, &out_len);
EXPECT_EQ(std::string(out, out_len), " ");
out = space_int64(ctx_ptr, -5, &out_len);
EXPECT_EQ(std::string(out, out_len), "");
out = space_int64(ctx_ptr, 65536, &out_len);
EXPECT_EQ(std::string(out, out_len), std::string(65536, ' '));
out = space_int64(ctx_ptr, 9223372036854775807, &out_len);
EXPECT_EQ(std::string(out, out_len), std::string(65536, ' '));
out = space_int64(ctx_ptr, -2639077559LL, &out_len);
EXPECT_EQ(std::string(out, out_len), "");
}

TEST(TestStringOps, TestIsSubstr) {
EXPECT_TRUE(is_substr_utf8_utf8("hello world", 11, "world", 5));
EXPECT_TRUE(is_substr_utf8_utf8("hello world", 11, "lo wo", 5));
Expand Down Expand Up @@ -739,6 +774,9 @@ TEST(TestStringOps, TestLpadString) {
out_str = lpad_utf8_int32_utf8(ctx_ptr, "hello", 5, 6, "д", 2, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "дhello");

out_str = lpad_utf8_int32_utf8(ctx_ptr, "大学路", 9, 65536, "", 3, &out_len);
EXPECT_EQ(out_len, 65536 * 3);

// LPAD function tests - with NO pad text
out_str = lpad_utf8_int32(ctx_ptr, "TestString", 10, 4, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "Test");
Expand All @@ -763,6 +801,12 @@ TEST(TestStringOps, TestLpadString) {

out_str = lpad_utf8_int32(ctx_ptr, "абвгд", 10, 7, &out_len);
EXPECT_EQ(std::string(out_str, out_len), " абвгд");

out_str = lpad_utf8_int32(ctx_ptr, "TestString", 10, 65537, &out_len);
EXPECT_EQ(std::string(out_str, out_len), std::string(65526, ' ') + "TestString");

out_str = lpad_utf8_int32(ctx_ptr, "TestString", 10, -1, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "");
}

TEST(TestStringOps, TestRpadString) {
Expand Down Expand Up @@ -808,6 +852,9 @@ TEST(TestStringOps, TestRpadString) {
out_str = rpad_utf8_int32_utf8(ctx_ptr, "hello", 5, 6, "д", 2, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "helloд");

out_str = rpad_utf8_int32_utf8(ctx_ptr, "大学路", 9, 655360, "哈雷路", 3, &out_len);
EXPECT_EQ(out_len, 65536 * 3);

// RPAD function tests - with NO pad text
out_str = rpad_utf8_int32(ctx_ptr, "TestString", 10, 4, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "Test");
Expand All @@ -832,6 +879,12 @@ TEST(TestStringOps, TestRpadString) {

out_str = rpad_utf8_int32(ctx_ptr, "абвгд", 10, 7, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "абвгд ");

out_str = rpad_utf8_int32(ctx_ptr, "TestString", 10, 65537, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "TestString" + std::string(65526, ' '));

out_str = rpad_utf8_int32(ctx_ptr, "TestString", 10, -1, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "");
}

TEST(TestStringOps, TestRtrim) {
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/gandiva/precompiled/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,8 @@ const char* concat_utf8_utf8_utf8_utf8(gdv_int64 context, const char* in1,
gdv_int32 in3_len, bool in3_validity,
const char* in4, gdv_int32 in4_len,
bool in4_validity, gdv_int32* out_len);
const char* space_int32(gdv_int64 ctx, gdv_int32 n, int32_t* out_len);
const char* space_int64(gdv_int64 ctx, gdv_int64 n, int32_t* out_len);
const char* concat_utf8_utf8_utf8_utf8_utf8(
gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity,
const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3,
Expand Down

0 comments on commit e43021d

Please sign in to comment.