Skip to content

Commit

Permalink
[Enhancement][Refactor] optimize array_contains_all/array_contains_se…
Browse files Browse the repository at this point in the history
…q function (StarRocks#51701)

Signed-off-by: silverbullet233 <[email protected]>
  • Loading branch information
silverbullet233 committed Nov 15, 2024
1 parent c789886 commit 755f8d9
Show file tree
Hide file tree
Showing 10 changed files with 817 additions and 7 deletions.
11 changes: 11 additions & 0 deletions be/src/column/column_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,17 @@ ColumnPtr ColumnHelper::convert_time_column_from_double_to_str(const ColumnPtr&
return res;
}

std::tuple<UInt32Column::Ptr, ColumnPtr, NullColumnPtr> ColumnHelper::unpack_array_column(const ColumnPtr& column) {
DCHECK(!column->is_nullable() && !column->is_constant());
DCHECK(column->is_array());

const ArrayColumn* array_column = down_cast<ArrayColumn*>(column.get());
auto elements_column = down_cast<NullableColumn*>(array_column->elements_column().get())->data_column();
auto null_column = down_cast<NullableColumn*>(array_column->elements_column().get())->null_column();
auto offsets_column = array_column->offsets_column();
return {offsets_column, elements_column, null_column};
}

template <class Ptr>
bool ChunkSliceTemplate<Ptr>::empty() const {
return !chunk || offset == chunk->num_rows();
Expand Down
2 changes: 2 additions & 0 deletions be/src/column/column_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,8 @@ class ColumnHelper {
static ColumnPtr create_const_null_column(size_t chunk_size);

static ColumnPtr convert_time_column_from_double_to_str(const ColumnPtr& column);
// unpack array column, return offsets_column, elements_column, elements_null_column
static std::tuple<UInt32Column::Ptr, ColumnPtr, NullColumnPtr> unpack_array_column(const ColumnPtr& column);

static NullColumnPtr one_size_not_null_column;

Expand Down
17 changes: 17 additions & 0 deletions be/src/exprs/array_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,24 @@ class ArrayFunctions {
DEFINE_VECTORIZED_FN(array_cum_sum_double);

DEFINE_VECTORIZED_FN(array_contains_any);

DEFINE_VECTORIZED_FN(array_contains_all);

template <LogicalType LT>
static StatusOr<ColumnPtr> array_contains_all_specific(FunctionContext* context, const Columns& columns) {
return ArrayContainsAll<LT, false>::process(context, columns);
}
template <LogicalType LT>
static Status array_contains_all_specific_prepare(FunctionContext* context,
FunctionContext::FunctionStateScope scope) {
return ArrayContainsAll<LT, false>::prepare(context, scope);
}
template <LogicalType LT>
static Status array_contains_all_specific_close(FunctionContext* context,
FunctionContext::FunctionStateScope scope) {
return ArrayContainsAll<LT, false>::close(context, scope);
}

DEFINE_VECTORIZED_FN(array_map);
DEFINE_VECTORIZED_FN(array_filter);
DEFINE_VECTORIZED_FN(all_match);
Expand Down
448 changes: 448 additions & 0 deletions be/src/exprs/array_functions.tpp

Large diffs are not rendered by default.

60 changes: 60 additions & 0 deletions be/src/util/bit_mask.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright 2021-present StarRocks, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstdint>
#include <cstdlib>
#include <cstring>

namespace starrocks {

class BitMask {
public:
BitMask(size_t size) {
_size = (size + 7) / 8;
_bits = new uint8_t[_size];
memset(_bits, 0, _size);
}
~BitMask() {
if (_bits) {
delete[] _bits;
}
}

void set_bit(size_t pos) { _bits[pos >> 3] |= (1 << (pos & 7)); }
// try to set bit in pos, if bit is already set, return false, otherwise return true
bool try_set_bit(size_t pos) {
if (is_bit_set(pos)) {
return false;
}
set_bit(pos);
return true;
}
void clear_bit(size_t pos) { _bits[pos >> 3] &= ~(1 << (pos & 7)); }

bool is_bit_set(size_t pos) { return (_bits[pos >> 3] & (1 << (pos & 7))) != 0; }

bool all_bits_zero() const {
for (size_t i = 0; i < _size; i++) {
if (_bits[i] != 0) {
return false;
}
}
return true;
}

private:
size_t _size;
uint8_t* _bits;
};
} // namespace starrocks
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ public class FunctionSet {
public static final String ARRAY_AVG = "array_avg";
public static final String ARRAY_CONTAINS = "array_contains";
public static final String ARRAY_CONTAINS_ALL = "array_contains_all";
public static final String ARRAY_CONTAINS_SEQ = "array_contains_seq";
public static final String ARRAY_CUM_SUM = "array_cum_sum";

public static final String ARRAY_JOIN = "array_join";
Expand Down Expand Up @@ -699,6 +700,7 @@ public class FunctionSet {
.add(ARRAY_CONCAT)
.add(ARRAY_SLICE)
.add(ARRAY_CONTAINS)
.add(ARRAY_CONTAINS_ALL)
.add(ARRAY_POSITION)
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ private static Type[] normalizeDecimalArgTypes(final Type[] argTypes, String fnN
return Arrays.stream(argTypes).map(t -> commonType).toArray(Type[]::new);
}

if (FunctionSet.ARRAYS_OVERLAP.equalsIgnoreCase(fnName)) {
if (FunctionSet.ARRAYS_OVERLAP.equalsIgnoreCase(fnName) ||
FunctionSet.ARRAY_CONTAINS_ALL.equalsIgnoreCase(fnName) ||
FunctionSet.ARRAY_CONTAINS_SEQ.equalsIgnoreCase(fnName)) {
Preconditions.checkState(argTypes.length == 2);
Type[] childTypes = Arrays.stream(argTypes).map(a -> {
if (a.isArrayType()) {
Expand Down Expand Up @@ -530,7 +532,9 @@ private static Function getArrayDecimalFunction(Function fn, Type[] argumentType
newFn.setRetType(new ArrayType(triple.returnType));
return newFn;
}
case FunctionSet.ARRAYS_OVERLAP: {
case FunctionSet.ARRAYS_OVERLAP:
case FunctionSet.ARRAY_CONTAINS_ALL:
case FunctionSet.ARRAY_CONTAINS_SEQ: {
newFn.setArgsType(argumentTypes);
return newFn;
}
Expand Down
26 changes: 21 additions & 5 deletions gensrc/script/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,11 +1177,27 @@
[150271, 'array_cum_sum', True, False, 'ARRAY_DOUBLE', ['ARRAY_DOUBLE'], 'ArrayFunctions::array_cum_sum_double'],

# reserve 150281
[150282, 'array_contains_all', True, False, 'BOOLEAN', ['ANY_ARRAY', 'ANY_ARRAY'], 'ArrayFunctions::array_contains_all'],

[150300, 'array_filter', True, False, 'ANY_ARRAY', ['ANY_ARRAY', 'ARRAY_BOOLEAN'], 'ArrayFunctions::array_filter'],
[150301, 'all_match', True, False, 'BOOLEAN', ['ARRAY_BOOLEAN'], 'ArrayFunctions::all_match'],
[150302, 'any_match', True, False, 'BOOLEAN', ['ARRAY_BOOLEAN'], 'ArrayFunctions::any_match'],
[150282, 'array_contains_all', True, False, 'BOOLEAN', ['ANY_ARRAY', 'ANY_ARRAY'],
'ArrayFunctions::array_contains_all'],
[15028201, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_BOOLEAN', 'ARRAY_BOOLEAN'], 'ArrayFunctions::array_contains_all_specific<TYPE_BOOLEAN>', 'ArrayFunctions::array_contains_all_specific_prepare<TYPE_BOOLEAN>', 'ArrayFunctions::array_contains_all_specific_close<TYPE_BOOLEAN>'],
[15028202, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_TINYINT', 'ARRAY_TINYINT'], 'ArrayFunctions::array_contains_all_specific<TYPE_TINYINT>', 'ArrayFunctions::array_contains_all_specific_prepare<TYPE_TINYINT>', 'ArrayFunctions::array_contains_all_specific_close<TYPE_TINYINT>'],
[15028203, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_SMALLINT', 'ARRAY_SMALLINT'], 'ArrayFunctions::array_contains_all_specific<TYPE_SMALLINT>', 'ArrayFunctions::array_contains_all_specific_prepare<TYPE_SMALLINT>', 'ArrayFunctions::array_contains_all_specific_close<TYPE_SMALLINT>'],
[15028204, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_INT', 'ARRAY_INT'], 'ArrayFunctions::array_contains_all_specific<TYPE_INT>', 'ArrayFunctions::array_contains_all_specific_prepare<TYPE_INT>', 'ArrayFunctions::array_contains_all_specific_close<TYPE_INT>'],
[15028205, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_BIGINT', 'ARRAY_BIGINT'], 'ArrayFunctions::array_contains_all_specific<TYPE_BIGINT>', 'ArrayFunctions::array_contains_all_specific_prepare<TYPE_BIGINT>', 'ArrayFunctions::array_contains_all_specific_close<TYPE_BIGINT>'],
[15028206, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_LARGEINT', 'ARRAY_LARGEINT'], 'ArrayFunctions::array_contains_all_specific<TYPE_LARGEINT>', 'ArrayFunctions::array_contains_all_specific_prepare<TYPE_LARGEINT>', 'ArrayFunctions::array_contains_all_specific_close<TYPE_LARGEINT>'],
[15028207, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_DECIMALV2', 'ARRAY_DECIMALV2'], 'ArrayFunctions::array_contains_all_specific<TYPE_DECIMALV2>', 'ArrayFunctions::array_contains_all_specific_prepare<TYPE_DECIMALV2>', 'ArrayFunctions::array_contains_all_specific_close<TYPE_DECIMALV2>'],
[15028208, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_DECIMAL32', 'ARRAY_DECIMAL32'], 'ArrayFunctions::array_contains_all_specific<TYPE_DECIMAL32>', 'ArrayFunctions::array_contains_all_specific_prepare<TYPE_DECIMAL32>', 'ArrayFunctions::array_contains_all_specific_close<TYPE_DECIMAL32>'],
[15028209, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_DECIMAL64', 'ARRAY_DECIMAL64'], 'ArrayFunctions::array_contains_all_specific<TYPE_DECIMAL64>', 'ArrayFunctions::array_contains_all_specific_prepare<TYPE_DECIMAL64>', 'ArrayFunctions::array_contains_all_specific_close<TYPE_DECIMAL64>'],
[15028210, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_DECIMAL128', 'ARRAY_DECIMAL128'], 'ArrayFunctions::array_contains_all_specific<TYPE_DECIMAL128>', 'ArrayFunctions::array_contains_all_specific_prepare<TYPE_DECIMAL128>', 'ArrayFunctions::array_contains_all_specific_close<TYPE_DECIMAL128>'],
[15028211, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_FLOAT', 'ARRAY_FLOAT'], 'ArrayFunctions::array_contains_all_specific<TYPE_FLOAT>', 'ArrayFunctions::array_contains_all_specific_prepare<TYPE_FLOAT>', 'ArrayFunctions::array_contains_all_specific_close<TYPE_FLOAT>'],
[15028212, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_DOUBLE', 'ARRAY_DOUBLE'], 'ArrayFunctions::array_contains_all_specific<TYPE_DOUBLE>', 'ArrayFunctions::array_contains_all_specific_prepare<TYPE_DOUBLE>', 'ArrayFunctions::array_contains_all_specific_close<TYPE_DOUBLE>'],
[15028213, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_VARCHAR', 'ARRAY_VARCHAR'], 'ArrayFunctions::array_contains_all_specific<TYPE_VARCHAR>', 'ArrayFunctions::array_contains_all_specific_prepare<TYPE_VARCHAR>', 'ArrayFunctions::array_contains_all_specific_close<TYPE_VARCHAR>'],
[15028214, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_DATE', 'ARRAY_DATE'], 'ArrayFunctions::array_contains_all_specific<TYPE_DATE>', 'ArrayFunctions::array_contains_all_specific_prepare<TYPE_DATE>', 'ArrayFunctions::array_contains_all_specific_close<TYPE_DATE>'],
[15028215, 'array_contains_all', True, False, 'BOOLEAN', ['ARRAY_DATETIME', 'ARRAY_DATETIME'], 'ArrayFunctions::array_contains_all_specific<TYPE_DATETIME>', 'ArrayFunctions::array_contains_all_specific_prepare<TYPE_DATETIME>', 'ArrayFunctions::array_contains_all_specific_close<TYPE_DATETIME>'],

[150300, 'array_filter', True, False, 'ANY_ARRAY', ['ANY_ARRAY', 'ARRAY_BOOLEAN'], 'ArrayFunctions::array_filter'],
[150301, 'all_match', True, False, 'BOOLEAN', ['ARRAY_BOOLEAN'], 'ArrayFunctions::all_match'],
[150302, 'any_match', True, False, 'BOOLEAN', ['ARRAY_BOOLEAN'], 'ArrayFunctions::any_match'],

[150311, 'array_sortby', True, False, 'ANY_ARRAY', ['ANY_ARRAY', 'ARRAY_BOOLEAN'],
'ArrayFunctions::array_sortby<TYPE_BOOLEAN>'],
Expand Down
206 changes: 206 additions & 0 deletions test/sql/test_array_fn/R/test_array_contains
Original file line number Diff line number Diff line change
Expand Up @@ -246,4 +246,210 @@ select array_position(v2, v1) from t;
select array_position(v3, v2) from t;
-- result:
1
-- !result
-- name: test_array_contains_all_and_seq
CREATE TABLE t (
k bigint(20) NOT NULL,
arr_0 array<bigint(20)> NOT NULL,
arr_1 array<bigint(20)>,
arr_2 array<bigint(20)>
) ENGINE=OLAP
DUPLICATE KEY(`k`)
DISTRIBUTED BY RANDOM BUCKETS 1
PROPERTIES (
"replication_num" = "1"
);
-- result:
-- !result
insert into t values
(1, [1,2,3], [1,2], [1]),
(2, [1,2,null], [1,null], [null]),
(3, [1,2,null],[3],[3]),
(4, [1,2,null], null, [1,2,null]),
(5, [1,2,null], [1,2,null], null),
(6, [1,2,3],[],[]),
(7, [null,null], [null,null,null], [null,null]),
(8, [1,1,1,1,1,2], [1,2], [1]),
(9, [1,1,1,1,1,null,2],[1,null,2],[null,2]);
-- result:
-- !result
select array_contains_all(arr_0, arr_1) from t order by k;
-- result:
1
1
0
None
1
1
1
1
1
-- !result
select array_contains_all(arr_1, arr_0) from t order by k;
-- result:
0
0
0
None
1
0
1
1
1
-- !result
select array_contains_all(arr_0, arr_2) from t order by k;
-- result:
1
1
0
1
None
1
1
1
1
-- !result
select array_contains_all(arr_2, arr_0) from t order by k;
-- result:
0
0
0
1
None
0
1
0
0
-- !result
select array_contains_all(arr_1, arr_2) from t order by k;
-- result:
1
1
1
None
None
1
1
1
1
-- !result
select array_contains_all(arr_2, arr_1) from t order by k;
-- result:
0
0
1
None
None
1
1
0
0
-- !result
select array_contains_all([1,2,3,4], arr_0) from t order by k;
-- result:
1
0
0
0
0
1
0
1
0
-- !result
select array_contains_all([1,2,3,4], arr_1) from t order by k;
-- result:
1
0
1
None
0
1
0
1
0
-- !result
select array_contains_all([1,2,3,4,null], arr_1) from t order by k;
-- result:
1
1
1
None
1
1
1
1
1
-- !result
select array_contains_all(arr_0, [1,null]) from t order by k;
-- result:
0
1
1
1
1
0
0
0
1
-- !result
select array_contains_all(arr_0, []) from t order by k;
-- result:
1
1
1
1
1
1
1
1
1
-- !result
select array_contains_all(null, arr_0) from t order by k;
-- result:
None
None
None
None
None
None
None
None
None
-- !result
select array_contains_all(arr_1, null) from t order by k;
-- result:
None
None
None
None
None
None
None
None
None
-- !result
set @arr0 = array_repeat("abcdefg", 10000);
-- result:
-- !result
set @arr1 = array_repeat("abcdef", 100000);
-- result:
-- !result
select array_contains_all(@arr0, @arr1);
-- result:
0
-- !result
set @arr0 = array_generate(10000);
-- result:
-- !result
set @arr1 = array_generate(20000);
-- result:
-- !result
select array_contains_all(@arr0, @arr1);
-- result:
0
-- !result
select array_contains_all(@arr1, @arr0);
-- result:
1
-- !result
Loading

0 comments on commit 755f8d9

Please sign in to comment.