Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix](function) fix Substring/SubReplace error result with input utf8… #40954

Merged
merged 4 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions be/src/util/simd/vstring_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,23 @@ class VStringFunctions {
}
}

// Iterate a UTF-8 string without exceeding a given length n.
// The function returns two values:
// the first represents the byte length traversed, and the second represents the char length traversed.
static inline std::pair<size_t, size_t> iterate_utf8_with_limit_length(const char* begin,
const char* end,
size_t n) {
const char* p = begin;
int char_size = 0;

size_t i = 0;
for (; i < n && p < end; ++i, p += char_size) {
char_size = UTF8_BYTE_LENGTH[static_cast<uint8_t>(*p)];
}

return {p - begin, i};
}

static void hex_encode(const unsigned char* src_str, size_t length, char* dst_str) {
static constexpr auto hex_table = "0123456789ABCDEF";
auto src_str_end = src_str + length;
Expand Down
134 changes: 96 additions & 38 deletions be/src/vec/functions/function_string.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,11 @@ struct SubstringUtil {
const char* str_data = (char*)chars.data() + offsets[i - 1];
int start_value = is_const ? start[0] : start[i];
int len_value = is_const ? len[0] : len[i];

// Unsigned numbers cannot be used here because start_value can be negative.
int char_len = simd::VStringFunctions::get_char_len(str_data, str_size);
// return empty string if start > src.length
if (start_value > str_size || str_size == 0 || start_value == 0 || len_value <= 0) {
// Here, start_value is compared against the length of the character.
if (start_value > char_len || str_size == 0 || start_value == 0 || len_value <= 0) {
StringOP::push_empty_string(i, res_chars, res_offsets);
continue;
}
Expand Down Expand Up @@ -3728,8 +3730,6 @@ class FunctionSubReplace : public IFunction {
return get_variadic_argument_types_impl().size();
}

bool use_default_implementation_for_nulls() const override { return false; }

Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) const override {
return Impl::execute_impl(context, block, arguments, result, input_rows_count);
Expand All @@ -3740,59 +3740,116 @@ struct SubReplaceImpl {
static Status replace_execute(Block& block, const ColumnNumbers& arguments, size_t result,
size_t input_rows_count) {
auto res_column = ColumnString::create();
auto result_column = assert_cast<ColumnString*>(res_column.get());
auto* result_column = assert_cast<ColumnString*>(res_column.get());
auto args_null_map = ColumnUInt8::create(input_rows_count, 0);
ColumnPtr argument_columns[4];
bool col_const[4];
for (int i = 0; i < 4; ++i) {
argument_columns[i] =
block.get_by_position(arguments[i]).column->convert_to_full_column_if_const();
if (auto* nullable = check_and_get_column<ColumnNullable>(*argument_columns[i])) {
// Danger: Here must dispose the null map data first! Because
// argument_columns[i]=nullable->get_nested_column_ptr(); will release the mem
// of column nullable mem of null map
VectorizedUtils::update_null_map(args_null_map->get_data(),
nullable->get_null_map_data());
argument_columns[i] = nullable->get_nested_column_ptr();
}
std::tie(argument_columns[i], col_const[i]) =
unpack_if_const(block.get_by_position(arguments[i]).column);
}

auto data_column = assert_cast<const ColumnString*>(argument_columns[0].get());
auto mask_column = assert_cast<const ColumnString*>(argument_columns[1].get());
auto start_column = assert_cast<const ColumnVector<Int32>*>(argument_columns[2].get());
auto length_column = assert_cast<const ColumnVector<Int32>*>(argument_columns[3].get());

vector(data_column, mask_column, start_column->get_data(), length_column->get_data(),
args_null_map->get_data(), result_column, input_rows_count);

const auto* data_column = assert_cast<const ColumnString*>(argument_columns[0].get());
const auto* mask_column = assert_cast<const ColumnString*>(argument_columns[1].get());
const auto* start_column =
assert_cast<const ColumnVector<Int32>*>(argument_columns[2].get());
const auto* length_column =
assert_cast<const ColumnVector<Int32>*>(argument_columns[3].get());

std::visit(
[&](auto origin_str_const, auto new_str_const, auto start_const, auto len_const) {
if (simd::VStringFunctions::is_ascii(
StringRef {data_column->get_chars().data(), data_column->size()})) {
vector_ascii<origin_str_const, new_str_const, start_const, len_const>(
data_column, mask_column, start_column->get_data(),
length_column->get_data(), args_null_map->get_data(), result_column,
input_rows_count);
} else {
vector_utf8<origin_str_const, new_str_const, start_const, len_const>(
data_column, mask_column, start_column->get_data(),
length_column->get_data(), args_null_map->get_data(), result_column,
input_rows_count);
}
},
vectorized::make_bool_variant(col_const[0]),
vectorized::make_bool_variant(col_const[1]),
vectorized::make_bool_variant(col_const[2]),
vectorized::make_bool_variant(col_const[3]));
block.get_by_position(result).column =
ColumnNullable::create(std::move(res_column), std::move(args_null_map));
return Status::OK();
}

private:
static void vector(const ColumnString* data_column, const ColumnString* mask_column,
const PaddedPODArray<Int32>& start, const PaddedPODArray<Int32>& length,
NullMap& args_null_map, ColumnString* result_column,
size_t input_rows_count) {
template <bool origin_str_const, bool new_str_const, bool start_const, bool len_const>
static void vector_ascii(const ColumnString* data_column, const ColumnString* mask_column,
const PaddedPODArray<Int32>& args_start,
const PaddedPODArray<Int32>& args_length, NullMap& args_null_map,
ColumnString* result_column, size_t input_rows_count) {
ColumnString::Chars& res_chars = result_column->get_chars();
ColumnString::Offsets& res_offsets = result_column->get_offsets();
for (size_t row = 0; row < input_rows_count; ++row) {
StringRef origin_str = data_column->get_data_at(row);
StringRef new_str = mask_column->get_data_at(row);
size_t origin_str_len = origin_str.size;
StringRef origin_str =
data_column->get_data_at(index_check_const<origin_str_const>(row));
StringRef new_str = mask_column->get_data_at(index_check_const<new_str_const>(row));
const auto start = args_start[index_check_const<start_const>(row)];
const auto length = args_length[index_check_const<len_const>(row)];
const size_t origin_str_len = origin_str.size;
//input is null, start < 0, len < 0, str_size <= start. return NULL
if (args_null_map[row] || start[row] < 0 || length[row] < 0 ||
origin_str_len <= start[row]) {
if (args_null_map[row] || start < 0 || length < 0 || origin_str_len <= start) {
res_offsets.push_back(res_chars.size());
args_null_map[row] = 1;
} else {
std::string_view replace_str = new_str.to_string_view();
std::string result = origin_str.to_string();
result.replace(start[row], length[row], replace_str);
result.replace(start, length, replace_str);
result_column->insert_data(result.data(), result.length());
}
}
}

template <bool origin_str_const, bool new_str_const, bool start_const, bool len_const>
static void vector_utf8(const ColumnString* data_column, const ColumnString* mask_column,
const PaddedPODArray<Int32>& args_start,
const PaddedPODArray<Int32>& args_length, NullMap& args_null_map,
ColumnString* result_column, size_t input_rows_count) {
ColumnString::Chars& res_chars = result_column->get_chars();
ColumnString::Offsets& res_offsets = result_column->get_offsets();

for (size_t row = 0; row < input_rows_count; ++row) {
StringRef origin_str =
data_column->get_data_at(index_check_const<origin_str_const>(row));
StringRef new_str = mask_column->get_data_at(index_check_const<new_str_const>(row));
const auto start = args_start[index_check_const<start_const>(row)];
const auto length = args_length[index_check_const<len_const>(row)];
//input is null, start < 0, len < 0 return NULL
if (args_null_map[row] || start < 0 || length < 0) {
res_offsets.push_back(res_chars.size());
args_null_map[row] = 1;
continue;
}

const auto [start_byte_len, start_char_len] =
simd::VStringFunctions::iterate_utf8_with_limit_length(origin_str.begin(),
origin_str.end(), start);

// start >= orgin.size
DCHECK(start_char_len <= start);
if (start_byte_len == origin_str.size) {
res_offsets.push_back(res_chars.size());
args_null_map[row] = 1;
continue;
}

auto [end_byte_len, end_char_len] =
simd::VStringFunctions::iterate_utf8_with_limit_length(
origin_str.begin() + start_byte_len, origin_str.end(), length);
DCHECK(end_char_len <= length);
std::string_view replace_str = new_str.to_string_view();
std::string result = origin_str.to_string();
result.replace(start_byte_len, end_byte_len, replace_str);
result_column->insert_data(result.data(), result.length());
}
}
};

struct SubReplaceThreeImpl {
Expand All @@ -3809,13 +3866,14 @@ struct SubReplaceThreeImpl {

auto str_col =
block.get_by_position(arguments[1]).column->convert_to_full_column_if_const();
if (auto* nullable = check_and_get_column<const ColumnNullable>(*str_col)) {
if (const auto* nullable = check_and_get_column<const ColumnNullable>(*str_col)) {
str_col = nullable->get_nested_column_ptr();
}
auto& str_offset = assert_cast<const ColumnString*>(str_col.get())->get_offsets();

const auto* str_column = assert_cast<const ColumnString*>(str_col.get());
// use utf8 len
for (int i = 0; i < input_rows_count; ++i) {
strlen_data[i] = str_offset[i] - str_offset[i - 1];
StringRef str_ref = str_column->get_data_at(i);
strlen_data[i] = simd::VStringFunctions::get_char_len(str_ref.data, str_ref.size);
}

block.insert({std::move(params), std::make_shared<DataTypeInt32>(), "strlen"});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,63 @@ tNEW-STRorigin str
-- !sql --
d***is

-- !sub_replace_utf8_sql1 --
你a世界

-- !sub_replace_utf8_sql2 --
你ab界

-- !sub_replace_utf8_sql3 --
你ab

-- !sub_replace_utf8_sql4 --
你abcd我界

-- !sub_replace_utf8_sql5 --
\N

-- !sub_replace_utf8_sql6 --
大家世界

-- !sub_replace_utf8_sql7 --
你大家114514

-- !sub_replace_utf8_sql8 --
\N

-- !sub_replace_utf8_sql9 --
\N

-- !sub_replace_utf8_sql10 --
\N

-- !sub_replace_utf8_sql1 --
你a世界

-- !sub_replace_utf8_sql2 --
你ab界

-- !sub_replace_utf8_sql3 --
你ab

-- !sub_replace_utf8_sql4 --
你abcd我界

-- !sub_replace_utf8_sql5 --
\N

-- !sub_replace_utf8_sql6 --
大家世界

-- !sub_replace_utf8_sql7 --
你大家114514

-- !sub_replace_utf8_sql8 --
\N

-- !sub_replace_utf8_sql9 --
\N

-- !sub_replace_utf8_sql10 --
\N

Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,25 @@ suite("test_string_function") {

qt_sql "select sub_replace(\"this is origin str\",\"NEW-STR\",1);"
qt_sql "select sub_replace(\"doris\",\"***\",1,2);"
qt_sub_replace_utf8_sql1 " select sub_replace('你好世界','a',1);"
qt_sub_replace_utf8_sql2 " select sub_replace('你好世界','ab',1);"
qt_sub_replace_utf8_sql3 " select sub_replace('你好世界','ab',1,20);"
qt_sub_replace_utf8_sql4 " select sub_replace('你好世界','abcd我',1,2);"
qt_sub_replace_utf8_sql5 " select sub_replace('你好世界','a',6);"
qt_sub_replace_utf8_sql6 " select sub_replace('你好世界','大家',0);"
qt_sub_replace_utf8_sql7 " select sub_replace('你好世界','大家114514',1,20);"
qt_sub_replace_utf8_sql8 " select sub_replace('你好世界','大家114514',6,20);"
qt_sub_replace_utf8_sql9 " select sub_replace('你好世界','大家',4);"
qt_sub_replace_utf8_sql10 " select sub_replace('你好世界','大家',-1);"
qt_sub_replace_utf8_sql1 " select sub_replace('你好世界','a',1);"
qt_sub_replace_utf8_sql2 " select sub_replace('你好世界','ab',1);"
qt_sub_replace_utf8_sql3 " select sub_replace('你好世界','ab',1,20);"
qt_sub_replace_utf8_sql4 " select sub_replace('你好世界','abcd我',1,2);"
qt_sub_replace_utf8_sql5 " select sub_replace('你好世界','a',6);"
qt_sub_replace_utf8_sql6 " select sub_replace('你好世界','大家',0);"
qt_sub_replace_utf8_sql7 " select sub_replace('你好世界','大家114514',1,20);"
qt_sub_replace_utf8_sql8 " select sub_replace('你好世界','大家114514',6,20);"
qt_sub_replace_utf8_sql9 " select sub_replace('你好世界','大家',4);"
qt_sub_replace_utf8_sql10 " select sub_replace('你好世界','大家',-1);"

}
Loading