From c1cbd1879d311fcffd85a333d707f1bef377f2fa Mon Sep 17 00:00:00 2001 From: eyes_on_me Date: Tue, 22 Oct 2024 13:59:37 +0800 Subject: [PATCH] [Enhancement][Refactor] optimize array_contains_all/array_contains_seq function (#51701) Signed-off-by: silverbullet233 <3675229+silverbullet233@users.noreply.github.com> (cherry picked from commit ac623b3de5fbd936660b9a5316ed04100aef2939) # Conflicts: # be/src/column/column_helper.h --- be/src/column/column_helper.cpp | 11 + be/src/column/column_helper.h | 5 + be/src/exprs/array_functions.h | 32 ++ be/src/exprs/array_functions.tpp | 448 ++++++++++++++++++ be/src/util/bit_mask.h | 8 + .../com/starrocks/catalog/FunctionSet.java | 3 + .../analyzer/DecimalV3FunctionAnalyzer.java | 8 +- gensrc/script/functions.py | 36 ++ test/sql/test_array_fn/R/test_array_contains | 412 ++++++++++++++++ test/sql/test_array_fn/T/test_array_contains | 75 ++- 10 files changed, 1035 insertions(+), 3 deletions(-) diff --git a/be/src/column/column_helper.cpp b/be/src/column/column_helper.cpp index 2cb9675f9288e..a2929fa759205 100644 --- a/be/src/column/column_helper.cpp +++ b/be/src/column/column_helper.cpp @@ -440,6 +440,17 @@ ColumnPtr ColumnHelper::convert_time_column_from_double_to_str(const ColumnPtr& return res; } +std::tuple ColumnHelper::unpack_array_column(const ColumnPtr& column) { + DCHECK(!column->is_nullable() && !column->is_constant()); + DCHECK(column->is_array()); + + const ArrayColumn* array_column = down_cast(column.get()); + auto elements_column = down_cast(array_column->elements_column().get())->data_column(); + auto null_column = down_cast(array_column->elements_column().get())->null_column(); + auto offsets_column = array_column->offsets_column(); + return {offsets_column, elements_column, null_column}; +} + template bool ChunkSliceTemplate::empty() const { return !chunk || offset == chunk->num_rows(); diff --git a/be/src/column/column_helper.h b/be/src/column/column_helper.h index a9c7dfb6e8242..64114e8a6335a 100644 --- a/be/src/column/column_helper.h +++ b/be/src/column/column_helper.h @@ -531,9 +531,14 @@ class ColumnHelper { static ColumnPtr convert_time_column_from_double_to_str(const ColumnPtr& column); +<<<<<<< HEAD static NullColumnPtr one_size_not_null_column; static NullColumnPtr one_size_null_column; +======= + // unpack array column, return offsets_column, elements_column, elements_null_column + static std::tuple unpack_array_column(const ColumnPtr& column); +>>>>>>> ac623b3de5 ([Enhancement][Refactor] optimize array_contains_all/array_contains_seq function (#51701)) }; // Hold a slice of chunk diff --git a/be/src/exprs/array_functions.h b/be/src/exprs/array_functions.h index 827f2156a4655..e49a9f52e106a 100644 --- a/be/src/exprs/array_functions.h +++ b/be/src/exprs/array_functions.h @@ -151,12 +151,44 @@ class ArrayFunctions { DEFINE_VECTORIZED_FN(array_cum_sum_double); DEFINE_VECTORIZED_FN(array_contains_any); + DEFINE_VECTORIZED_FN(array_contains_all); + + template + static StatusOr array_contains_all_specific(FunctionContext* context, const Columns& columns) { + return ArrayContainsAll::process(context, columns); + } + template + static Status array_contains_all_specific_prepare(FunctionContext* context, + FunctionContext::FunctionStateScope scope) { + return ArrayContainsAll::prepare(context, scope); + } + template + static Status array_contains_all_specific_close(FunctionContext* context, + FunctionContext::FunctionStateScope scope) { + return ArrayContainsAll::close(context, scope); + } + DEFINE_VECTORIZED_FN(array_map); DEFINE_VECTORIZED_FN(array_filter); DEFINE_VECTORIZED_FN(all_match); DEFINE_VECTORIZED_FN(any_match); + DEFINE_VECTORIZED_FN(array_contains_seq); + template + static StatusOr array_contains_seq_specific(FunctionContext* context, const Columns& columns) { + return ArrayContainsAll::process(context, columns); + } + template + static Status array_contains_seq_specific_prepare(FunctionContext* context, + FunctionContext::FunctionStateScope scope) { + return ArrayContainsAll::prepare(context, scope); + } + template + static Status array_contains_seq_specific_close(FunctionContext* context, + FunctionContext::FunctionStateScope scope) { + return ArrayContainsAll::close(context, scope); + } // array function for nested type(Array/Map/Struct) DEFINE_VECTORIZED_FN(array_distinct_any_type); diff --git a/be/src/exprs/array_functions.tpp b/be/src/exprs/array_functions.tpp index f078592adaa04..98156e271864a 100644 --- a/be/src/exprs/array_functions.tpp +++ b/be/src/exprs/array_functions.tpp @@ -27,6 +27,7 @@ #include "runtime/current_thread.h" #include "runtime/runtime_state.h" #include "types/logical_type.h" +#include "util/bit_mask.h" #include "util/orlp/pdqsort.h" #include "util/phmap/phmap.h" @@ -1947,4 +1948,451 @@ private: } }; +// Implementation of array_contains_all and array_contains_seq +// for array_contains_all, we build hash table to speed up the search. +// for array_contains_seq, we use the idea of ​​KMP algorithm to speed up the search. +template +class ArrayContainsAll { + using CppType = RunTimeCppType; + using ColumnType = RunTimeColumnType; + using HashFunc = PhmapDefaultHashFunc; + using HashMap = phmap::flat_hash_map; + using PrefixTable = std::vector; + + struct ArrayContainsAllState { + bool has_null = false; + // final result, only used when both two inputs are constant + bool contains = false; + std::variant variant; + }; + +public: + static Status prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope) { + if (scope != FunctionContext::FRAGMENT_LOCAL) { + return Status::OK(); + } + + if constexpr (!HashFunc::is_supported()) { + return Status::OK(); + } + + bool is_left_notnull_const = context->is_notnull_constant_column(0); + bool is_right_notnull_const = context->is_notnull_constant_column(1); + + if constexpr (ContainsSeq) { + if (!is_right_notnull_const) { + return Status::OK(); + } + } else { + if (!is_left_notnull_const && !is_right_notnull_const) { + return Status::OK(); + } + } + + auto* state = new ArrayContainsAllState(); + context->set_function_state(scope, state); + + ColumnPtr column; + if constexpr (ContainsSeq) { + column = context->get_constant_column(1); + } else { + // for array_contains_all, prefer to use the left column to build hash table + column = is_left_notnull_const ? context->get_constant_column(0) : context->get_constant_column(1); + } + ColumnPtr array_column = FunctionHelper::get_data_column_of_const(column); + const auto& [offsets_column, elements_column, null_column] = ColumnHelper::unpack_array_column(array_column); + const CppType* elements_data = reinterpret_cast(elements_column->raw_data()); + const NullColumn::ValueType* null_data = null_column->raw_data(); + const UInt32Column::ValueType* offsets_data = offsets_column->get_data().data(); + size_t offset = offsets_data[0]; + size_t array_size = offsets_data[1] - offset; + + if constexpr (ContainsSeq) { + state->variant = PrefixTable{}; + _build_prefix_table(elements_data, null_data, offset, array_size, state); + } else { + state->variant = HashMap{}; + _build_hash_table(elements_data, null_data, offset, array_size, state); + } + + if (is_left_notnull_const && is_right_notnull_const) { + // if both inputs are constant, we just compute result directly + ColumnPtr target_column = ContainsSeq ? context->get_constant_column(0) : context->get_constant_column(1); + const auto& [target_offsets_column, target_elements_column, target_null_column] = + ColumnHelper::unpack_array_column(FunctionHelper::get_data_column_of_const(target_column)); + + const CppType* target_elements_data = reinterpret_cast(target_elements_column->raw_data()); + const NullColumn::ValueType* target_elements_null_data = target_null_column->raw_data(); + const UInt32Column::ValueType* target_offsets_data = target_offsets_column->get_data().data(); + + size_t target_offset = target_offsets_data[0]; + size_t target_array_size = target_offsets_data[1] - offset; + if constexpr (ContainsSeq) { + state->contains = _process_with_prefix_table(state, target_elements_data, elements_data, + target_elements_null_data, null_data, target_offset, + target_array_size, offset, array_size); + } else { + state->contains = _process_with_hash_table(state, target_elements_data, target_elements_null_data, + target_offset, target_array_size); + } + } + + return Status::OK(); + } + + static Status close(FunctionContext* context, FunctionContext::FunctionStateScope scope) { + if (scope == FunctionContext::FRAGMENT_LOCAL) { + auto* state = reinterpret_cast( + context->get_function_state(FunctionContext::FRAGMENT_LOCAL)); + delete state; + } + return Status::OK(); + } + + static StatusOr process(FunctionContext* context, const Columns& columns) { + if constexpr (!is_supported(LT)) { + return Status::NotSupported(fmt::format("not support type {}", LT)); + } + RETURN_IF_COLUMNS_ONLY_NULL(columns); + + const ColumnPtr& left_column = columns[0]; + const ColumnPtr& right_column = columns[1]; + bool is_const_left = left_column->is_constant(); + bool is_const_right = right_column->is_constant(); + [[maybe_unused]] auto* state = + reinterpret_cast(context->get_function_state(FunctionContext::FRAGMENT_LOCAL)); + + if (is_const_left && is_const_right) { + // if both input columns are constant, return result directly + auto result_column = BooleanColumn::create(); + result_column->append(state->contains); + return ConstColumn::create(std::move(result_column), columns[0]->size()); + } + + NullColumnPtr result_null_column = nullptr; + + ColumnPtr left_data_column = FunctionHelper::get_data_column_of_const(left_column); + bool is_nullable_left = left_data_column->is_nullable(); + const NullColumn::ValueType* left_null_data = nullptr; + if (is_nullable_left) { + left_null_data = down_cast(left_data_column.get())->null_column_data().data(); + left_data_column = down_cast(left_data_column.get())->data_column(); + } + + ColumnPtr right_data_column = FunctionHelper::get_data_column_of_const(right_column); + const NullColumn::ValueType* right_null_data = nullptr; + bool is_nullable_right = right_data_column->is_nullable(); + if (is_nullable_right) { + right_null_data = down_cast(right_data_column.get())->null_column_data().data(); + right_data_column = down_cast(right_data_column.get())->data_column(); + } + + if (is_nullable_left && is_nullable_right) { + return _process(state, left_data_column, right_data_column, left_null_data, right_null_data, + is_const_left, is_const_right); + } else if (is_nullable_left && !is_nullable_right) { + return _process(state, left_data_column, right_data_column, left_null_data, right_null_data, + is_const_left, is_const_right); + } else if (!is_nullable_left && is_nullable_right) { + return _process(state, left_data_column, right_data_column, left_null_data, right_null_data, + is_const_left, is_const_right); + } else { + return _process(state, left_data_column, right_data_column, left_null_data, right_null_data, + is_const_left, is_const_right); + } + } + +private: + static constexpr bool is_supported(LogicalType type) { return is_scalar_logical_type(type); } + + static void _build_hash_table(const CppType* elements_data, const NullColumn::ValueType* elements_null_data, + size_t offset, size_t array_size, ArrayContainsAllState* state) { + HashMap* hash_map = std::get_if(&(state->variant)); + DCHECK(hash_map != nullptr); + + size_t count = 0; + for (size_t i = 0; i < array_size; i++) { + if (elements_null_data[offset + i]) { + state->has_null = true; + continue; + } + const auto& value = elements_data[offset + i]; + if (!hash_map->contains(value)) { + hash_map->insert({value, count++}); + } + } + } + + template + static bool _process_with_hash_table(const ArrayContainsAllState* state, const CppType* elements_data, + const NullColumn::ValueType* elements_null_data, size_t offset, + size_t array_size) { + const HashMap* hash_map = std::get_if(&(state->variant)); + DCHECK(hash_map != nullptr); + // for array_contains_all(left, right), hash table may be built from the left or the right. + // if ht comes from left side, all the data of right side must be found in ht. + // if ht comes from right side, all the data in ht must be appeared in the left side. + + if (hash_map->empty()) { + if (state->has_null) { + size_t null_elements_num = SIMD::count_nonzero(elements_null_data + offset, array_size); + return HTFromLeft ? null_elements_num == array_size : null_elements_num > 0; + } else { + return HTFromLeft ? array_size == 0 : true; + } + } + + if constexpr (HTFromLeft) { + for (size_t i = 0; i < array_size; i++) { + if (elements_null_data[i + offset]) { + if (!state->has_null) { + return false; + } + continue; + } + const auto& value = elements_data[i + offset]; + if (!hash_map->contains(value)) { + return false; + } + } + } else { + BitMask bit_mask(hash_map->size()); + size_t find_count = 0; + bool has_null = false; + for (size_t i = 0; i < array_size; i++) { + if (elements_null_data[i + offset]) { + has_null = true; + continue; + } + const auto& value = elements_data[i + offset]; + auto iter = hash_map->find(value); + if (iter != hash_map->end()) { + size_t idx = iter->second; + find_count += bit_mask.try_set_bit(idx); + } + } + if (!(has_null == state->has_null && find_count == hash_map->size())) { + return false; + } + } + + return true; + } + + static inline bool _check_element_equal(const CppType* left_data, const NullColumn::ValueType* left_null_data, + const CppType* right_data, const NullColumn::ValueType* right_null_data, + size_t lhs, size_t rhs) { + bool is_lhs_null = left_null_data[lhs]; + bool is_rhs_null = right_null_data[rhs]; + if (is_lhs_null ^ is_rhs_null) { + return false; + } + if (is_lhs_null & is_rhs_null) { + return true; + } + return left_data[lhs] == right_data[rhs]; + } + + static void _build_prefix_table(const CppType* elements_data, const NullColumn::ValueType* null_data, size_t offset, + size_t array_size, ArrayContainsAllState* state) { + if (array_size == 0) { + return; + } + PrefixTable* prefix_table = std::get_if(&(state->variant)); + DCHECK(prefix_table != nullptr); + prefix_table->resize(array_size); + + (*prefix_table)[0] = 0; + size_t length = 0; + size_t idx = 1; + while (idx < array_size) { + if (_check_element_equal(elements_data, null_data, elements_data, null_data, offset + idx, + offset + length)) { + length++; + (*prefix_table)[idx] = length; + idx++; + } else { + if (length != 0) { + length = (*prefix_table)[length - 1]; + } else { + (*prefix_table)[idx] = 0; + idx++; + } + } + } + } + + static bool _process_with_prefix_table(const ArrayContainsAllState* state, const CppType* left_elements_data, + const CppType* right_elements_data, + const NullColumn::ValueType* left_elements_null_data, + const NullColumn::ValueType* right_elements_null_data, size_t left_offset, + size_t left_array_size, size_t right_offset, size_t right_array_size) { + if (right_array_size == 0) { + return true; + } + if (right_array_size > left_array_size) { + return false; + } + const PrefixTable* prefix_table = std::get_if(&(state->variant)); + DCHECK(prefix_table != nullptr && !prefix_table->empty()); + DCHECK_EQ(prefix_table->size(), right_array_size); + + size_t left_idx = 0; + size_t right_idx = 0; + while (left_idx < left_array_size) { + bool is_equal = + _check_element_equal(left_elements_data, left_elements_null_data, right_elements_data, + right_elements_null_data, left_offset + left_idx, right_offset + right_idx); + if (is_equal) { + left_idx++; + right_idx++; + } + if (right_idx == right_array_size) { + return true; + } else if (left_idx < left_array_size && + !_check_element_equal(left_elements_data, left_elements_null_data, right_elements_data, + right_elements_null_data, left_offset + left_idx, + right_offset + right_idx)) { + if (right_idx != 0) { + right_idx = (*prefix_table)[right_idx - 1]; + } else { + left_idx++; + } + } + } + return false; + } + + template + static ColumnPtr _process(const ArrayContainsAllState* state, const ColumnPtr& left_arrays, + const ColumnPtr& right_arrays, const NullColumn::ValueType* left_null_data, + const NullColumn::ValueType* right_null_data, bool is_const_left, bool is_const_right) { + DCHECK(!left_arrays->is_constant() && !left_arrays->is_nullable() && left_arrays->is_array()); + DCHECK(!right_arrays->is_constant() && !right_arrays->is_nullable() && right_arrays->is_array()); + if (!is_const_left && !is_const_right) { + DCHECK_EQ(left_arrays->size(), right_arrays->size()); + } + + const auto& [left_offsets_column, left_elements_column, left_elements_null_column] = + ColumnHelper::unpack_array_column(left_arrays); + const CppType* left_elements_data = reinterpret_cast(left_elements_column->raw_data()); + const NullColumn::ValueType* left_elements_null_data = left_elements_null_column->get_data().data(); + const auto* left_offsets_data = left_offsets_column->get_data().data(); + + const auto& [right_offsets_column, right_elements_column, right_elements_null_column] = + ColumnHelper::unpack_array_column(right_arrays); + const CppType* right_elements_data = reinterpret_cast(right_elements_column->raw_data()); + const NullColumn::ValueType* right_elements_null_data = right_elements_null_column->get_data().data(); + const auto* right_offsets_data = right_offsets_column->get_data().data(); + + size_t num_rows = (is_const_left && is_const_right) ? 1 : std::max(left_arrays->size(), right_arrays->size()); + + auto result_column = BooleanColumn::create(); + result_column->resize(num_rows); + auto* result_data = result_column->get_data().data(); + + [[maybe_unused]] NullColumnPtr result_null_column; + [[maybe_unused]] NullColumn::ValueType* result_null_data = nullptr; + if constexpr (NullableLeft || NullableRight) { + result_null_column = NullColumn::create(); + result_null_column->resize(num_rows); + result_null_data = result_null_column->get_data().data(); + } + + for (size_t i = 0; i < num_rows; i++) { + if constexpr (NullableLeft) { + bool is_array_null = is_const_left ? left_null_data[0] : left_null_data[i]; + if (is_array_null) { + result_data[i] = false; + result_null_data[i] = 1; + continue; + } + } + if constexpr (NullableRight) { + bool is_array_null = is_const_right ? right_null_data[0] : right_null_data[i]; + if (is_array_null) { + result_data[i] = false; + result_null_data[i] = 1; + continue; + } + } + + size_t left_array_offset = is_const_left ? left_offsets_data[0] : left_offsets_data[i]; + size_t left_array_size = is_const_left ? left_offsets_data[1] - left_offsets_data[0] + : left_offsets_data[i + 1] - left_offsets_data[i]; + size_t left_null_element_num = + NullableLeft ? 0 + : SIMD::count_nonzero(left_elements_null_data + left_array_offset, left_array_size); + size_t left_not_null_element_num = left_array_size - left_null_element_num; + + size_t right_array_offset = is_const_right ? right_offsets_data[0] : right_offsets_data[i]; + size_t right_array_size = is_const_right ? right_offsets_data[1] - right_offsets_data[0] + : right_offsets_data[i + 1] - right_offsets_data[i]; + size_t right_null_element_num = + NullableRight + ? 0 + : SIMD::count_nonzero(right_elements_null_data + right_array_offset, right_array_size); + size_t right_not_null_element_num = right_array_size - right_null_element_num; + + [[maybe_unused]] const ArrayContainsAllState* state_ref = nullptr; + [[maybe_unused]] ArrayContainsAllState tmp_state; + if constexpr (ContainsSeq) { + if (is_const_right) { + state_ref = state; + } else { + tmp_state.variant = PrefixTable{}; + _build_prefix_table(right_elements_data, right_elements_null_data, right_array_offset, + right_array_size, &tmp_state); + state_ref = &tmp_state; + } + result_data[i] = + _process_with_prefix_table(state_ref, left_elements_data, right_elements_data, + left_elements_null_data, right_elements_null_data, left_array_offset, + left_array_size, right_array_offset, right_array_size); + } else { + bool build_from_left; + + if (is_const_left || is_const_right) { + state_ref = state; + build_from_left = is_const_left; + } else { + tmp_state.variant = HashMap{}; + // we build hash table on the side with less elements + build_from_left = left_not_null_element_num <= right_not_null_element_num; + const CppType* build_elements_data = build_from_left ? left_elements_data : right_elements_data; + const NullColumn::ValueType* build_elements_null_data = + build_from_left ? left_elements_null_data : right_elements_null_data; + size_t build_array_offset = build_from_left ? left_array_offset : right_array_offset; + size_t build_array_size = build_from_left ? left_array_size : right_array_size; + + _build_hash_table(build_elements_data, build_elements_null_data, build_array_offset, + build_array_size, &tmp_state); + state_ref = &tmp_state; + } + + const CppType* probe_elements_data = !build_from_left ? left_elements_data : right_elements_data; + const NullColumn::ValueType* probe_elements_null_data = + !build_from_left ? left_elements_null_data : right_elements_null_data; + size_t probe_array_offset = !build_from_left ? left_array_offset : right_array_offset; + size_t probe_array_size = !build_from_left ? left_array_size : right_array_size; + + result_data[i] = build_from_left ? _process_with_hash_table(state_ref, probe_elements_data, + probe_elements_null_data, + probe_array_offset, probe_array_size) + : _process_with_hash_table( + state_ref, probe_elements_data, probe_elements_null_data, + probe_array_offset, probe_array_size); + } + + if constexpr (NullableLeft || NullableRight) { + result_null_data[i] = 0; + } + } + if constexpr (NullableLeft || NullableRight) { + return NullableColumn::create(result_column, result_null_column); + } + return result_column; + } +}; + } // namespace starrocks diff --git a/be/src/util/bit_mask.h b/be/src/util/bit_mask.h index 9dc0ac6f17294..ae0f3cab24076 100644 --- a/be/src/util/bit_mask.h +++ b/be/src/util/bit_mask.h @@ -32,6 +32,14 @@ class BitMask { } void set_bit(size_t pos) { _bits[pos >> 3] |= (1 << (pos & 7)); } + // try to set bit in pos, if bit is already set, return false, otherwise return true + bool try_set_bit(size_t pos) { + if (is_bit_set(pos)) { + return false; + } + set_bit(pos); + return true; + } void clear_bit(size_t pos) { _bits[pos >> 3] &= ~(1 << (pos & 7)); } bool is_bit_set(size_t pos) { return (_bits[pos >> 3] & (1 << (pos & 7))) != 0; } diff --git a/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java b/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java index fb10e6697ddcb..1477f50eaaa22 100644 --- a/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java @@ -323,6 +323,7 @@ public class FunctionSet { public static final String ARRAY_AVG = "array_avg"; public static final String ARRAY_CONTAINS = "array_contains"; public static final String ARRAY_CONTAINS_ALL = "array_contains_all"; + public static final String ARRAY_CONTAINS_SEQ = "array_contains_seq"; public static final String ARRAY_CUM_SUM = "array_cum_sum"; public static final String ARRAY_JOIN = "array_join"; @@ -711,6 +712,8 @@ public class FunctionSet { .add(ARRAY_CONCAT) .add(ARRAY_SLICE) .add(ARRAY_CONTAINS) + .add(ARRAY_CONTAINS_ALL) + .add(ARRAY_CONTAINS_SEQ) .add(ARRAY_POSITION) .build(); diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/DecimalV3FunctionAnalyzer.java b/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/DecimalV3FunctionAnalyzer.java index 6be095eba5039..78db40cbdd33b 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/DecimalV3FunctionAnalyzer.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/DecimalV3FunctionAnalyzer.java @@ -114,7 +114,9 @@ private static Type[] normalizeDecimalArgTypes(final Type[] argTypes, String fnN return Arrays.stream(argTypes).map(t -> commonType).toArray(Type[]::new); } - if (FunctionSet.ARRAYS_OVERLAP.equalsIgnoreCase(fnName)) { + if (FunctionSet.ARRAYS_OVERLAP.equalsIgnoreCase(fnName) || + FunctionSet.ARRAY_CONTAINS_ALL.equalsIgnoreCase(fnName) || + FunctionSet.ARRAY_CONTAINS_SEQ.equalsIgnoreCase(fnName)) { Preconditions.checkState(argTypes.length == 2); Type[] childTypes = Arrays.stream(argTypes).map(a -> { if (a.isArrayType()) { @@ -530,7 +532,9 @@ private static Function getArrayDecimalFunction(Function fn, Type[] argumentType newFn.setRetType(new ArrayType(triple.returnType)); return newFn; } - case FunctionSet.ARRAYS_OVERLAP: { + case FunctionSet.ARRAYS_OVERLAP: + case FunctionSet.ARRAY_CONTAINS_ALL: + case FunctionSet.ARRAY_CONTAINS_SEQ: { newFn.setArgsType(argumentTypes); return newFn; } diff --git a/gensrc/script/functions.py b/gensrc/script/functions.py index a5b47665ccc51..b6840b9a70b12 100644 --- a/gensrc/script/functions.py +++ b/gensrc/script/functions.py @@ -1206,8 +1206,44 @@ # reserve 150281 [150282, 'array_contains_all', True, False, 'BOOLEAN', ['ANY_ARRAY', 'ANY_ARRAY'], 'ArrayFunctions::array_contains_all'], + [15028201, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_BOOLEAN', 'ARRAY_BOOLEAN'], 'ArrayFunctions::array_contains_all_specific', 'ArrayFunctions::array_contains_all_specific_prepare', 'ArrayFunctions::array_contains_all_specific_close'], + [15028202, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_TINYINT', 'ARRAY_TINYINT'], 'ArrayFunctions::array_contains_all_specific', 'ArrayFunctions::array_contains_all_specific_prepare', 'ArrayFunctions::array_contains_all_specific_close'], + [15028203, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_SMALLINT', 'ARRAY_SMALLINT'], 'ArrayFunctions::array_contains_all_specific', 'ArrayFunctions::array_contains_all_specific_prepare', 'ArrayFunctions::array_contains_all_specific_close'], + [15028204, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_INT', 'ARRAY_INT'], 'ArrayFunctions::array_contains_all_specific', 'ArrayFunctions::array_contains_all_specific_prepare', 'ArrayFunctions::array_contains_all_specific_close'], + [15028205, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_BIGINT', 'ARRAY_BIGINT'], 'ArrayFunctions::array_contains_all_specific', 'ArrayFunctions::array_contains_all_specific_prepare', 'ArrayFunctions::array_contains_all_specific_close'], + [15028206, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_LARGEINT', 'ARRAY_LARGEINT'], 'ArrayFunctions::array_contains_all_specific', 'ArrayFunctions::array_contains_all_specific_prepare', 'ArrayFunctions::array_contains_all_specific_close'], + [15028207, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_DECIMALV2', 'ARRAY_DECIMALV2'], 'ArrayFunctions::array_contains_all_specific', 'ArrayFunctions::array_contains_all_specific_prepare', 'ArrayFunctions::array_contains_all_specific_close'], + [15028208, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_DECIMAL32', 'ARRAY_DECIMAL32'], 'ArrayFunctions::array_contains_all_specific', 'ArrayFunctions::array_contains_all_specific_prepare', 'ArrayFunctions::array_contains_all_specific_close'], + [15028209, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_DECIMAL64', 'ARRAY_DECIMAL64'], 'ArrayFunctions::array_contains_all_specific', 'ArrayFunctions::array_contains_all_specific_prepare', 'ArrayFunctions::array_contains_all_specific_close'], + [15028210, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_DECIMAL128', 'ARRAY_DECIMAL128'], 'ArrayFunctions::array_contains_all_specific', 'ArrayFunctions::array_contains_all_specific_prepare', 'ArrayFunctions::array_contains_all_specific_close'], + [15028211, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_FLOAT', 'ARRAY_FLOAT'], 'ArrayFunctions::array_contains_all_specific', 'ArrayFunctions::array_contains_all_specific_prepare', 'ArrayFunctions::array_contains_all_specific_close'], + [15028212, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_DOUBLE', 'ARRAY_DOUBLE'], 'ArrayFunctions::array_contains_all_specific', 'ArrayFunctions::array_contains_all_specific_prepare', 'ArrayFunctions::array_contains_all_specific_close'], + [15028213, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_VARCHAR', 'ARRAY_VARCHAR'], 'ArrayFunctions::array_contains_all_specific', 'ArrayFunctions::array_contains_all_specific_prepare', 'ArrayFunctions::array_contains_all_specific_close'], + [15028214, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_DATE', 'ARRAY_DATE'], 'ArrayFunctions::array_contains_all_specific', 'ArrayFunctions::array_contains_all_specific_prepare', 'ArrayFunctions::array_contains_all_specific_close'], + [15028215, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_DATETIME', 'ARRAY_DATETIME'], 'ArrayFunctions::array_contains_all_specific', 'ArrayFunctions::array_contains_all_specific_prepare', 'ArrayFunctions::array_contains_all_specific_close'], + + + + # TODO: sepecific type [150283, 'array_contains_seq', True, False, 'BOOLEAN', ['ANY_ARRAY', 'ANY_ARRAY'], 'ArrayFunctions::array_contains_seq'], + [15028301, 'array_contains_seq', True, False, 'BOOLEAN', ['ARRAY_BOOLEAN', 'ARRAY_BOOLEAN'], 'ArrayFunctions::array_contains_seq_specific', 'ArrayFunctions::array_contains_seq_specific_prepare', 'ArrayFunctions::array_contains_seq_specific_close'], + [15028302, 'array_contains_seq', True, False, 'BOOLEAN', ['ARRAY_TINYINT', 'ARRAY_TINYINT'], 'ArrayFunctions::array_contains_seq_specific', 'ArrayFunctions::array_contains_seq_specific_prepare', 'ArrayFunctions::array_contains_seq_specific_close'], + [15028303, 'array_contains_seq', True, False, 'BOOLEAN', ['ARRAY_SMALLINT', 'ARRAY_SMALLINT'], 'ArrayFunctions::array_contains_seq_specific', 'ArrayFunctions::array_contains_seq_specific_prepare', 'ArrayFunctions::array_contains_seq_specific_close'], + [15028304, 'array_contains_seq', True, False, 'BOOLEAN', ['ARRAY_INT', 'ARRAY_INT'], 'ArrayFunctions::array_contains_seq_specific', 'ArrayFunctions::array_contains_seq_specific_prepare', 'ArrayFunctions::array_contains_seq_specific_close'], + [15028305, 'array_contains_seq', True, False, 'BOOLEAN', ['ARRAY_BIGINT', 'ARRAY_BIGINT'], 'ArrayFunctions::array_contains_seq_specific', 'ArrayFunctions::array_contains_seq_specific_prepare', 'ArrayFunctions::array_contains_seq_specific_close'], + [15028306, 'array_contains_seq', True, False, 'BOOLEAN', ['ARRAY_LARGEINT', 'ARRAY_LARGEINT'], 'ArrayFunctions::array_contains_seq_specific', 'ArrayFunctions::array_contains_seq_specific_prepare', 'ArrayFunctions::array_contains_seq_specific_close'], + [15028307, 'array_contains_seq', True, False, 'BOOLEAN', ['ARRAY_DECIMALV2', 'ARRAY_DECIMALV2'], 'ArrayFunctions::array_contains_seq_specific', 'ArrayFunctions::array_contains_seq_specific_prepare', 'ArrayFunctions::array_contains_seq_specific_close'], + [15028308, 'array_contains_seq', True, False, 'BOOLEAN', ['ARRAY_DECIMAL32', 'ARRAY_DECIMAL32'], 'ArrayFunctions::array_contains_seq_specific', 'ArrayFunctions::array_contains_seq_specific_prepare', 'ArrayFunctions::array_contains_seq_specific_close'], + [15028309, 'array_contains_seq', True, False, 'BOOLEAN', ['ARRAY_DECIMAL64', 'ARRAY_DECIMAL64'], 'ArrayFunctions::array_contains_seq_specific', 'ArrayFunctions::array_contains_seq_specific_prepare', 'ArrayFunctions::array_contains_seq_specific_close'], + [15028310, 'array_contains_seq', True, False, 'BOOLEAN', ['ARRAY_DECIMAL128', 'ARRAY_DECIMAL128'], 'ArrayFunctions::array_contains_seq_specific', 'ArrayFunctions::array_contains_seq_specific_prepare', 'ArrayFunctions::array_contains_seq_specific_close'], + [15028311, 'array_contains_seq', True, False, 'BOOLEAN', ['ARRAY_FLOAT', 'ARRAY_FLOAT'], 'ArrayFunctions::array_contains_seq_specific', 'ArrayFunctions::array_contains_seq_specific_prepare', 'ArrayFunctions::array_contains_seq_specific_close'], + [15028312, 'array_contains_seq', True, False, 'BOOLEAN', ['ARRAY_DOUBLE', 'ARRAY_DOUBLE'], 'ArrayFunctions::array_contains_seq_specific', 'ArrayFunctions::array_contains_seq_specific_prepare', 'ArrayFunctions::array_contains_seq_specific_close'], + [15028313, 'array_contains_seq', True, False, 'BOOLEAN', ['ARRAY_VARCHAR', 'ARRAY_VARCHAR'], 'ArrayFunctions::array_contains_seq_specific', 'ArrayFunctions::array_contains_seq_specific_prepare', 'ArrayFunctions::array_contains_seq_specific_close'], + [15028314, 'array_contains_seq', True, False, 'BOOLEAN', ['ARRAY_DATE', 'ARRAY_DATE'], 'ArrayFunctions::array_contains_seq_specific', 'ArrayFunctions::array_contains_seq_specific_prepare', 'ArrayFunctions::array_contains_seq_specific_close'], + [15028315, 'array_contains_seq', True, False, 'BOOLEAN', ['ARRAY_DATETIME', 'ARRAY_DATETIME'], 'ArrayFunctions::array_contains_seq_specific', 'ArrayFunctions::array_contains_seq_specific_prepare', 'ArrayFunctions::array_contains_seq_specific_close'], + + [150300, 'array_filter', True, False, 'ANY_ARRAY', ['ANY_ARRAY', 'ARRAY_BOOLEAN'], 'ArrayFunctions::array_filter'], [150301, 'all_match', True, False, 'BOOLEAN', ['ARRAY_BOOLEAN'], 'ArrayFunctions::all_match'], [150302, 'any_match', True, False, 'BOOLEAN', ['ARRAY_BOOLEAN'], 'ArrayFunctions::any_match'], diff --git a/test/sql/test_array_fn/R/test_array_contains b/test/sql/test_array_fn/R/test_array_contains index 34623ba442bac..1f09f1c2ad0c4 100644 --- a/test/sql/test_array_fn/R/test_array_contains +++ b/test/sql/test_array_fn/R/test_array_contains @@ -246,4 +246,416 @@ select array_position(v2, v1) from t; select array_position(v3, v2) from t; -- result: 1 +-- !result +-- name: test_array_contains_all_and_seq +CREATE TABLE t ( + k bigint(20) NOT NULL, + arr_0 array NOT NULL, + arr_1 array, + arr_2 array +) ENGINE=OLAP +DUPLICATE KEY(`k`) +DISTRIBUTED BY RANDOM BUCKETS 1 +PROPERTIES ( +"replication_num" = "1" +); +-- result: +-- !result +insert into t values +(1, [1,2,3], [1,2], [1]), +(2, [1,2,null], [1,null], [null]), +(3, [1,2,null],[3],[3]), +(4, [1,2,null], null, [1,2,null]), +(5, [1,2,null], [1,2,null], null), +(6, [1,2,3],[],[]), +(7, [null,null], [null,null,null], [null,null]), +(8, [1,1,1,1,1,2], [1,2], [1]), +(9, [1,1,1,1,1,null,2],[1,null,2],[null,2]); +-- result: +-- !result +select array_contains_all(arr_0, arr_1) from t order by k; +-- result: +1 +1 +0 +None +1 +1 +1 +1 +1 +-- !result +select array_contains_all(arr_1, arr_0) from t order by k; +-- result: +0 +0 +0 +None +1 +0 +1 +1 +1 +-- !result +select array_contains_all(arr_0, arr_2) from t order by k; +-- result: +1 +1 +0 +1 +None +1 +1 +1 +1 +-- !result +select array_contains_all(arr_2, arr_0) from t order by k; +-- result: +0 +0 +0 +1 +None +0 +1 +0 +0 +-- !result +select array_contains_all(arr_1, arr_2) from t order by k; +-- result: +1 +1 +1 +None +None +1 +1 +1 +1 +-- !result +select array_contains_all(arr_2, arr_1) from t order by k; +-- result: +0 +0 +1 +None +None +1 +1 +0 +0 +-- !result +select array_contains_all([1,2,3,4], arr_0) from t order by k; +-- result: +1 +0 +0 +0 +0 +1 +0 +1 +0 +-- !result +select array_contains_all([1,2,3,4], arr_1) from t order by k; +-- result: +1 +0 +1 +None +0 +1 +0 +1 +0 +-- !result +select array_contains_all([1,2,3,4,null], arr_1) from t order by k; +-- result: +1 +1 +1 +None +1 +1 +1 +1 +1 +-- !result +select array_contains_all(arr_0, [1,null]) from t order by k; +-- result: +0 +1 +1 +1 +1 +0 +0 +0 +1 +-- !result +select array_contains_all(arr_0, []) from t order by k; +-- result: +1 +1 +1 +1 +1 +1 +1 +1 +1 +-- !result +select array_contains_all(null, arr_0) from t order by k; +-- result: +None +None +None +None +None +None +None +None +None +-- !result +select array_contains_all(arr_1, null) from t order by k; +-- result: +None +None +None +None +None +None +None +None +None +-- !result +set @arr0 = array_repeat("abcdefg", 10000); +-- result: +-- !result +set @arr1 = array_repeat("abcdef", 100000); +-- result: +-- !result +select array_contains_all(@arr0, @arr1); +-- result: +0 +-- !result +set @arr0 = array_generate(10000); +-- result: +-- !result +set @arr1 = array_generate(20000); +-- result: +-- !result +select array_contains_all(@arr0, @arr1); +-- result: +0 +-- !result +select array_contains_all(@arr1, @arr0); +-- result: +1 +-- !result +select array_contains_seq(arr_0, arr_1) from t order by k; +-- result: +1 +0 +0 +None +1 +1 +0 +1 +1 +-- !result +select array_contains_seq(arr_1, arr_0) from t order by k; +-- result: +0 +0 +0 +None +1 +0 +1 +0 +0 +-- !result +select array_contains_seq(arr_0, arr_2) from t order by k; +-- result: +1 +1 +0 +1 +None +1 +1 +1 +1 +-- !result +select array_contains_seq(arr_2, arr_0) from t order by k; +-- result: +0 +0 +0 +1 +None +0 +1 +0 +0 +-- !result +select array_contains_seq(arr_1, arr_2) from t order by k; +-- result: +1 +1 +1 +None +None +1 +1 +1 +1 +-- !result +select array_contains_seq(arr_2, arr_1) from t order by k; +-- result: +0 +0 +1 +None +None +1 +0 +0 +0 +-- !result +select array_contains_seq([1,2,3,4], arr_0) from t order by k; +-- result: +1 +0 +0 +0 +0 +1 +0 +0 +0 +-- !result +select array_contains_seq([1,2,3,4], arr_1) from t order by k; +-- result: +1 +0 +1 +None +0 +1 +0 +1 +0 +-- !result +select array_contains_seq([1,2,3,4,null], arr_1) from t order by k; +-- result: +1 +0 +1 +None +0 +1 +0 +1 +0 +-- !result +select array_contains_seq(arr_0, [1,null]) from t order by k; +-- result: +0 +0 +0 +0 +0 +0 +0 +0 +1 +-- !result +select array_contains_seq(arr_0, []) from t order by k; +-- result: +1 +1 +1 +1 +1 +1 +1 +1 +1 +-- !result +select array_contains_seq(null, arr_0) from t order by k; +-- result: +None +None +None +None +None +None +None +None +None +-- !result +select array_contains_seq(arr_1, null) from t order by k; +-- result: +None +None +None +None +None +None +None +None +None +-- !result +select array_contains_seq([1,1,2,3],[1,1]); +-- result: +1 +-- !result +select array_contains_seq([1,1,2,3],[1,2]); +-- result: +1 +-- !result +select array_contains_seq([1,1,2,3],[1,3]); +-- result: +0 +-- !result +select array_contains_seq([1,1,2,3],[2,3]); +-- result: +1 +-- !result +select array_contains_seq([1,1,2,3],[1,1,2]); +-- result: +1 +-- !result +select array_contains_seq([null,null,1,2],[null]); +-- result: +1 +-- !result +select array_contains_seq([null,null,1,2],[null,null]); +-- result: +1 +-- !result +select array_contains_seq([null,null,1,2],[null,1]); +-- result: +1 +-- !result +select array_contains_seq([null,null,1,2],[null,null,1]); +-- result: +1 +-- !result +select array_contains_seq([null,null,1,2],[null,1,2]); +-- result: +1 +-- !result +set @arr0 = array_append(array_repeat(1, 10000), 2); +-- result: +-- !result +set @arr1 = array_append(array_repeat(1, 5000), 2); +-- result: +-- !result +select array_contains_seq(@arr0, @arr1); +-- result: +1 -- !result \ No newline at end of file diff --git a/test/sql/test_array_fn/T/test_array_contains b/test/sql/test_array_fn/T/test_array_contains index 4b8799df1bcea..dbfee11d6a541 100644 --- a/test/sql/test_array_fn/T/test_array_contains +++ b/test/sql/test_array_fn/T/test_array_contains @@ -67,4 +67,77 @@ select array_position(v1, 1.1) from t; select array_position(v2, [1.1]) from t; select array_position(v3, [[1.1]]) from t; select array_position(v2, v1) from t; -select array_position(v3, v2) from t; \ No newline at end of file +select array_position(v3, v2) from t; + +-- name: test_array_contains_all_and_seq +CREATE TABLE t ( + k bigint(20) NOT NULL, + arr_0 array NOT NULL, + arr_1 array, + arr_2 array +) ENGINE=OLAP +DUPLICATE KEY(`k`) +DISTRIBUTED BY RANDOM BUCKETS 1 +PROPERTIES ( +"replication_num" = "1" +); +insert into t values +(1, [1,2,3], [1,2], [1]), +(2, [1,2,null], [1,null], [null]), +(3, [1,2,null],[3],[3]), +(4, [1,2,null], null, [1,2,null]), +(5, [1,2,null], [1,2,null], null), +(6, [1,2,3],[],[]), +(7, [null,null], [null,null,null], [null,null]), +(8, [1,1,1,1,1,2], [1,2], [1]), +(9, [1,1,1,1,1,null,2],[1,null,2],[null,2]); +select array_contains_all(arr_0, arr_1) from t order by k; +select array_contains_all(arr_1, arr_0) from t order by k; +select array_contains_all(arr_0, arr_2) from t order by k; +select array_contains_all(arr_2, arr_0) from t order by k; +select array_contains_all(arr_1, arr_2) from t order by k; +select array_contains_all(arr_2, arr_1) from t order by k; +select array_contains_all([1,2,3,4], arr_0) from t order by k; +select array_contains_all([1,2,3,4], arr_1) from t order by k; +select array_contains_all([1,2,3,4,null], arr_1) from t order by k; +select array_contains_all(arr_0, [1,null]) from t order by k; +select array_contains_all(arr_0, []) from t order by k; +select array_contains_all(null, arr_0) from t order by k; +select array_contains_all(arr_1, null) from t order by k; + +set @arr0 = array_repeat("abcdefg", 10000); +set @arr1 = array_repeat("abcdef", 100000); +select array_contains_all(@arr0, @arr1); +set @arr0 = array_generate(10000); +set @arr1 = array_generate(20000); +select array_contains_all(@arr0, @arr1); +select array_contains_all(@arr1, @arr0); + +select array_contains_seq(arr_0, arr_1) from t order by k; +select array_contains_seq(arr_1, arr_0) from t order by k; +select array_contains_seq(arr_0, arr_2) from t order by k; +select array_contains_seq(arr_2, arr_0) from t order by k; +select array_contains_seq(arr_1, arr_2) from t order by k; +select array_contains_seq(arr_2, arr_1) from t order by k; +select array_contains_seq([1,2,3,4], arr_0) from t order by k; +select array_contains_seq([1,2,3,4], arr_1) from t order by k; +select array_contains_seq([1,2,3,4,null], arr_1) from t order by k; +select array_contains_seq(arr_0, [1,null]) from t order by k; +select array_contains_seq(arr_0, []) from t order by k; +select array_contains_seq(null, arr_0) from t order by k; +select array_contains_seq(arr_1, null) from t order by k; + +select array_contains_seq([1,1,2,3],[1,1]); +select array_contains_seq([1,1,2,3],[1,2]); +select array_contains_seq([1,1,2,3],[1,3]); +select array_contains_seq([1,1,2,3],[2,3]); +select array_contains_seq([1,1,2,3],[1,1,2]); +select array_contains_seq([null,null,1,2],[null]); +select array_contains_seq([null,null,1,2],[null,null]); +select array_contains_seq([null,null,1,2],[null,1]); +select array_contains_seq([null,null,1,2],[null,null,1]); +select array_contains_seq([null,null,1,2],[null,1,2]); + +set @arr0 = array_append(array_repeat(1, 10000), 2); +set @arr1 = array_append(array_repeat(1, 5000), 2); +select array_contains_seq(@arr0, @arr1);