Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cherry-pick](branch-21) cherry-pick pr about (#42488) (#42099) (#42055) #42916

Merged
merged 3 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 156 additions & 0 deletions be/src/vec/functions/function_bit_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <bit>
#include <bitset>

#include "common/status.h"
#include "vec/columns/column.h"
#include "vec/columns/column_vector.h"
#include "vec/common/assert_cast.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type_number.h"
#include "vec/functions/cast_type_to_either.h"
#include "vec/functions/simple_function_factory.h"

namespace doris::vectorized {

/// Vectorized implementation of `bit_test(value, pos [, pos ...])`.
///
/// For every row the result is 1 when each requested bit position is set in
/// `value`, and 0 when any position is unset or invalid (negative, or not
/// smaller than the bit width of the value type). All arguments are read as
/// columns of the same integer type as the first argument.
class FunctionBitTest : public IFunction {
public:
    static constexpr auto name = "bit_test";

    static FunctionPtr create() { return std::make_shared<FunctionBitTest>(); }

    String get_name() const override { return name; }

    // Variadic: one value followed by any number of bit positions.
    size_t get_number_of_arguments() const override { return 0; }

    bool is_variadic() const override { return true; }

    /// The result is always an Int8 flag (0 or 1).
    DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
        return std::make_shared<DataTypeInt8>();
    }

    /// Dispatches on the integer type of the first argument and writes the
    /// result column into `block` at position `result`.
    ///
    /// Returns a RuntimeError when the first argument is neither one of the
    /// supported integer types nor stored as a plain vector / const column.
    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
                        size_t result, size_t input_rows_count) const override {
        const auto& first_argument = block.get_by_position(arguments[0]);
        const bool valid = cast_type(first_argument.type.get(), [&](const auto& type) {
            using DataType = std::decay_t<decltype(type)>;
            using T = typename DataType::FieldType;
            const bool is_plain_vector =
                    check_and_get_column<ColumnVector<T>>(first_argument.column.get()) != nullptr;
            if (is_plain_vector || is_column_const(*first_argument.column)) {
                execute_inner<T>(block, arguments, result, input_rows_count);
                return true;
            }
            return false;
        });
        if (!valid) {
            return Status::RuntimeError(
                    "{}'s argument does not match the expected data type, type: {}, column: {}",
                    get_name(), block.get_by_position(arguments[0]).type->get_name(),
                    block.get_by_position(arguments[0]).column->dump_structure());
        }
        return Status::OK();
    }

    /// Invokes `f` with the concrete data type of `type` when it is one of
    /// Int8/Int16/Int32/Int64/Int128; returns whatever `cast_type_to_either`
    /// reports.
    template <typename F>
    static bool cast_type(const IDataType* type, F&& f) {
        return cast_type_to_either<DataTypeInt8, DataTypeInt16, DataTypeInt32, DataTypeInt64,
                                   DataTypeInt128>(type, std::forward<F>(f));
    }

    /// Computes the result column for value type `T` and stores it in `block`.
    template <typename T>
    void execute_inner(Block& block, const ColumnNumbers& arguments, size_t result,
                       size_t input_rows_count) const {
        const size_t argument_size = arguments.size();
        std::vector<ColumnPtr> argument_columns(argument_size);
        // Start from "all bits set"; the workers only ever clear entries.
        auto result_data_column = ColumnInt8::create(input_rows_count, 1);
        auto& res_data = result_data_column->get_data();

        if (argument_size == 2) {
            // bit_test(column, const) is expected to be the common call shape, so
            // only this arity avoids materializing const columns.
            std::vector<uint8_t> is_consts(argument_size);
            std::tie(argument_columns[0], is_consts[0]) =
                    unpack_if_const(block.get_by_position(arguments[0]).column);
            std::tie(argument_columns[1], is_consts[1]) =
                    unpack_if_const(block.get_by_position(arguments[1]).column);
            execute_for_two_argument<T>(argument_columns, is_consts, res_data, input_rows_count);
        } else {
            for (size_t i = 0; i < argument_size; ++i) {
                argument_columns[i] = block.get_by_position(arguments[i])
                                              .column->convert_to_full_column_if_const();
            }
            execute_for_others_arg<T>(argument_columns, res_data, argument_size, input_rows_count);
        }

        block.replace_by_position(result, std::move(result_data_column));
    }

    /// Two-argument form: `res_data[i]` becomes the tested bit of the value,
    /// or 0 when the position is out of range. `is_consts[k]` tells whether
    /// `argument_columns[k]` is an unpacked const column (always read at row 0).
    template <typename T>
    void execute_for_two_argument(std::vector<ColumnPtr>& argument_columns,
                                  std::vector<uint8_t>& is_consts, ColumnInt8::Container& res_data,
                                  size_t input_rows_count) const {
        const auto& value_data =
                assert_cast<const ColumnVector<T>&>(*argument_columns[0].get()).get_data();
        const auto& position_data =
                assert_cast<const ColumnVector<T>&>(*argument_columns[1].get()).get_data();
        const bool value_is_const = is_consts[0];
        const bool position_is_const = is_consts[1];
        for (size_t i = 0; i < input_rows_count; ++i) {
            const auto value = value_data[index_check_const(i, value_is_const)];
            const auto position = position_data[index_check_const(i, position_is_const)];
            // An invalid position yields 0; checking first also keeps the shift well defined.
            if (position < 0 || position >= sizeof(T) * 8) {
                res_data[i] = 0;
                continue;
            }
            res_data[i] = ((value >> position) & 1);
        }
    }

    /// General form: clears `res_data[i]` as soon as any position column holds
    /// an invalid position or a position whose bit is not set in the value.
    /// Expects full (non-const) columns and `res_data` pre-filled with 1.
    template <typename T>
    void execute_for_others_arg(std::vector<ColumnPtr>& argument_columns,
                                ColumnInt8::Container& res_data, size_t argument_size,
                                size_t input_rows_count) const {
        const auto& value_data =
                assert_cast<const ColumnVector<T>&>(*argument_columns[0].get()).get_data();
        // Walk one position column at a time so the column cast is resolved once
        // per column instead of once per row; the outcome per row is the same
        // because a row ends up 0 iff at least one of its positions fails.
        for (size_t col = 1; col < argument_size; ++col) {
            const auto& position_data =
                    assert_cast<const ColumnVector<T>&>(*argument_columns[col].get()).get_data();
            for (size_t i = 0; i < input_rows_count; ++i) {
                const auto position = position_data[i];
                // The range check must come first: it guards the shift below.
                if (position < 0 || position >= sizeof(T) * 8 ||
                    !((value_data[i] >> position) & 1)) {
                    res_data[i] = 0;
                }
            }
        }
    }
};

// Adds FunctionBitTest to the factory and makes it reachable through the
// additional alias "bit_test_all".
void register_function_bit_test(SimpleFunctionFactory& factory) {
    factory.register_function<FunctionBitTest>();
    factory.register_alias(FunctionBitTest::name, "bit_test_all");
}

} // namespace doris::vectorized
1 change: 1 addition & 0 deletions be/src/vec/functions/function_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1175,6 +1175,7 @@ void register_function_string(SimpleFunctionFactory& factory) {
factory.register_function<FunctionFromBase64>();
factory.register_function<FunctionSplitPart>();
factory.register_function<FunctionSplitByString>();
factory.register_function<FunctionCountSubString>();
factory.register_function<FunctionSubstringIndex>();
factory.register_function<FunctionExtractURLParameter>();
factory.register_function<FunctionStringParseUrl>();
Expand Down
116 changes: 116 additions & 0 deletions be/src/vec/functions/function_string.h
Original file line number Diff line number Diff line change
Expand Up @@ -2733,6 +2733,122 @@ class FunctionSplitByString : public IFunction {
}
};

/// Vectorized implementation of `count_substrings(str, pattern)`.
///
/// Produces, per row, the number of non-overlapping occurrences of `pattern`
/// in `str`, scanning left to right. An empty string or an empty pattern
/// yields 0.
class FunctionCountSubString : public IFunction {
public:
    static constexpr auto name = "count_substrings";

    static FunctionPtr create() { return std::make_shared<FunctionCountSubString>(); }
    using NullMapType = PaddedPODArray<UInt8>;

    String get_name() const override { return name; }

    size_t get_number_of_arguments() const override { return 2; }

    /// Both arguments are expected to be strings; the result is Int32.
    DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
        DCHECK(is_string(arguments[0]))
                << "first argument for function: " << name << " should be string"
                << " and arguments[0] is " << arguments[0]->get_name();
        DCHECK(is_string(arguments[1]))
                << "second argument for function: " << name << " should be string"
                << " and arguments[1] is " << arguments[1]->get_name();
        return std::make_shared<DataTypeInt32>();
    }

    /// Counts the occurrences for every row and stores the Int32 column in
    /// `block` at position `result`.
    ///
    /// Returns an InternalError when either argument is not a ColumnString
    /// (after unwrapping a const column).
    Status execute_impl(FunctionContext* /*context*/, Block& block, const ColumnNumbers& arguments,
                        size_t result, size_t input_rows_count) const override {
        DCHECK_EQ(arguments.size(), 2);
        const auto& [src_column, left_const] =
                unpack_if_const(block.get_by_position(arguments[0]).column);
        const auto& [right_column, right_const] =
                unpack_if_const(block.get_by_position(arguments[1]).column);

        const auto* col_left = check_and_get_column<ColumnString>(src_column.get());
        if (!col_left) {
            return Status::InternalError("Left operator of function {} can not be {}", get_name(),
                                         block.get_by_position(arguments[0]).type->get_name());
        }

        const auto* col_right = check_and_get_column<ColumnString>(right_column.get());
        if (!col_right) {
            return Status::InternalError("Right operator of function {} can not be {}", get_name(),
                                         block.get_by_position(arguments[1]).type->get_name());
        }

        auto dest_column_ptr = ColumnInt32::create(input_rows_count, 0);
        auto& dest_data = dest_column_ptr->get_data();
        if (right_const) {
            // count_substrings(ColumnString, "xxx")
            _execute_constant_pattern(*col_left, col_right->get_data_at(0), dest_data,
                                      input_rows_count);
        } else if (left_const) {
            // count_substrings("xxx", ColumnString)
            _execute_constant_src_string(col_left->get_data_at(0), *col_right, dest_data,
                                         input_rows_count);
        } else {
            // count_substrings(ColumnString, ColumnString)
            _execute_vector(*col_left, *col_right, dest_data, input_rows_count);
        }

        block.replace_by_position(result, std::move(dest_column_ptr));
        return Status::OK();
    }

private:
    /// Per-row source string, one shared pattern.
    void _execute_constant_pattern(const ColumnString& src_column_string,
                                   const StringRef& pattern_ref,
                                   ColumnInt32::Container& dest_column_data,
                                   size_t input_rows_count) const {
        for (size_t i = 0; i < input_rows_count; i++) {
            const StringRef str_ref = src_column_string.get_data_at(i);
            dest_column_data[i] = find_str_count(str_ref, pattern_ref);
        }
    }

    /// Per-row source string and per-row pattern.
    void _execute_vector(const ColumnString& src_column_string, const ColumnString& pattern_column,
                         ColumnInt32::Container& dest_column_data, size_t input_rows_count) const {
        for (size_t i = 0; i < input_rows_count; i++) {
            const StringRef pattern_ref = pattern_column.get_data_at(i);
            const StringRef str_ref = src_column_string.get_data_at(i);
            dest_column_data[i] = find_str_count(str_ref, pattern_ref);
        }
    }

    /// One shared source string, per-row pattern.
    void _execute_constant_src_string(const StringRef& str_ref, const ColumnString& pattern_col,
                                      ColumnInt32::Container& dest_column_data,
                                      size_t input_rows_count) const {
        for (size_t i = 0; i < input_rows_count; ++i) {
            const StringRef pattern_ref = pattern_col.get_data_at(i);
            dest_column_data[i] = find_str_count(str_ref, pattern_ref);
        }
    }

    /// Searches `str_ref` for `pattern_ref`, starting at byte offset `pos`.
    ///
    /// Returns the distance from `pos` to the first match, or
    /// `str_ref.size - pos` when there is no match (0 when `pos` is already at
    /// or past the end).
    size_t find_pos(size_t pos, const StringRef str_ref, const StringRef pattern_ref) const {
        const size_t start = pos;
        const size_t str_size = str_ref.size;
        const size_t pattern_size = pattern_ref.size;
        if (start >= str_size) {
            return 0;
        }
        if (pattern_size <= str_size) {
            // A match has to fit completely inside the string. Comparing at later
            // offsets would read bytes beyond the string's end, and if those bytes
            // happened to complete the pattern a bogus match would be reported.
            const size_t last_start = str_size - pattern_size;
            while (pos <= last_start &&
                   memcmp_small_allow_overflow15(str_ref.data + pos, pattern_ref.data,
                                                 pattern_size)) {
                pos++;
            }
            if (pos <= last_start) {
                return pos - start;
            }
        }
        return str_size - start; // not found
    }

    /// Counts non-overlapping occurrences of `pattern_ref` in `str_ref`;
    /// returns 0 when either of them is empty.
    int find_str_count(const StringRef str_ref, StringRef pattern_ref) const {
        if (str_ref.size == 0 || pattern_ref.size == 0) {
            return 0;
        }
        int count = 0;
        size_t str_pos = 0;
        while (str_pos <= str_ref.size) {
            const size_t offset = find_pos(str_pos, str_ref, pattern_ref);
            if (offset == (str_ref.size - str_pos)) {
                break; // not found
            }
            count++;
            // Resume right after the match so occurrences never overlap.
            str_pos += offset + pattern_ref.size;
        }
        return count;
    }
};

struct SM3Sum {
static constexpr auto name = "sm3sum";
using ObjectData = SM3Digest;
Expand Down
2 changes: 2 additions & 0 deletions be/src/vec/functions/simple_function_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ void register_function_tokenize(SimpleFunctionFactory& factory);
void register_function_url(SimpleFunctionFactory& factory);
void register_function_ip(SimpleFunctionFactory& factory);
void register_function_multi_match(SimpleFunctionFactory& factory);
void register_function_bit_test(SimpleFunctionFactory& factory);

class SimpleFunctionFactory {
using Creator = std::function<FunctionBuilderPtr()>;
Expand Down Expand Up @@ -297,6 +298,7 @@ class SimpleFunctionFactory {
register_function_ignore(instance);
register_function_variant_element(instance);
register_function_multi_match(instance);
register_function_bit_test(instance);
});
return instance;
}
Expand Down
Loading
Loading