From 0624378160b1d19de12a29b1374dea1930c9faaa Mon Sep 17 00:00:00 2001 From: Igor Izvekov Date: Wed, 19 Jul 2023 00:13:45 +0300 Subject: [PATCH] feat: array functions treat an array as an element (#6986) --- .../tests/sqllogictests/test_files/array.slt | 196 +++++++++++++----- .../physical-expr/src/array_expressions.rs | 59 +++++- 2 files changed, 205 insertions(+), 50 deletions(-) diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt index f0f50ccc9340..1e9b32414bba 100644 --- a/datafusion/core/tests/sqllogictests/test_files/array.slt +++ b/datafusion/core/tests/sqllogictests/test_files/array.slt @@ -55,6 +55,13 @@ AS VALUES (make_array(make_array(15, 16),make_array(NULL, 18)), make_array(16.6, 17.7, 18.8), NULL) ; +statement ok +CREATE TABLE nested_arrays +AS VALUES + (make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6)), make_array(7, 8, 9), 2, make_array([[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]])), + (make_array(make_array(4, 5, 6), make_array(10, 11, 12), make_array(4, 9, 8), make_array(7, 8, 9), make_array(10, 11, 12), make_array(1, 8, 7)), make_array(10, 11, 12), 3, make_array([[11, 12, 13], [14, 15, 16]], [[17, 18, 19], [20, 21, 22]])) +; + statement ok CREATE TABLE arrays_values AS VALUES @@ -100,6 +107,13 @@ NULL [13.3, 14.4, 15.5] [a, m, e, t] [[11, 12], [13, 14]] NULL [,] [[15, 16], [, 18]] [16.6, 17.7, 18.8] NULL +# nested_arrays table +query ??I? +select column1, column2, column3, column4 from nested_arrays; +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [7, 8, 9] 2 [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]] +[[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [10, 11, 12] 3 [[[11, 12, 13], [14, 15, 16]], [[17, 18, 19], [20, 21, 22]]] + # values table query IIIRT select a, b, c, d, e from values; @@ -292,7 +306,13 @@ select array_append(make_array(1, 2, 3), 4), array_append(make_array(1.0, 2.0, 3 ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] -# array_append with columns +# array_append scalar function #4 (element is list) +query ??? +select array_append(make_array([1], [2], [3]), make_array(4)), array_append(make_array([1.0], [2.0], [3.0]), make_array(4.0)), array_append(make_array(['h'], ['e'], ['l'], ['l']), make_array('o')); +---- +[[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + +# array_append with columns #1 query ? select array_append(column1, column2) from arrays_values; ---- @@ -305,7 +325,14 @@ select array_append(column1, column2) from arrays_values; [51, 52, , 54, 55, 56, 57, 58, 59, 60, 55] [61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 66] -# array_append with columns and scalars +# array_append with columns #2 (element is list) +query ? +select array_append(column1, column2) from nested_arrays; +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [7, 8, 9]] +[[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [10, 11, 12]] + +# array_append with columns and scalars #1 query ?? select array_append(column2, 100.1), array_append(column3, '.') from arrays; ---- @@ -317,6 +344,13 @@ select array_append(column2, 100.1), array_append(column3, '.') from arrays; [100.1] [,, .] [16.6, 17.7, 18.8, 100.1] [.] +# array_append with columns and scalars #2 +query ?? +select array_append(column1, make_array(1, 11, 111)), array_append(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), column2) from nested_arrays; +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [7, 8, 9]] +[[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [10, 11, 12]] + ## array_prepend # array_prepend scalar function #1 @@ -337,7 +371,13 @@ select array_prepend(1, make_array(2, 3, 4)), array_prepend(1.0, make_array(2.0, ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] -# array_prepend with columns +# array_prepend scalar function #4 (element is list) +query ??? +select array_prepend(make_array(1), make_array(make_array(2), make_array(3), make_array(4))), array_prepend(make_array(1.0), make_array([2.0], [3.0], [4.0])), array_prepend(make_array('h'), make_array(['e'], ['l'], ['l'], ['o'])); +---- +[[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + +# array_prepend with columns #1 query ? select array_prepend(column2, column1) from arrays_values; ---- @@ -350,7 +390,14 @@ select array_prepend(column2, column1) from arrays_values; [55, 51, 52, , 54, 55, 56, 57, 58, 59, 60] [66, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70] -# array_prepend with columns and scalars +# array_prepend with columns #2 (element is list) +query ? +select array_prepend(column2, column1) from nested_arrays; +---- +[[7, 8, 9], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] +[[10, 11, 12], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] + +# array_prepend with columns and scalars #1 query ?? select array_prepend(100.1, column2), array_prepend('.', column3) from arrays; ---- @@ -362,6 +409,13 @@ select array_prepend(100.1, column2), array_prepend('.', column3) from arrays; [100.1] [., ,] [100.1, 16.6, 17.7, 18.8] [.] +# array_prepend with columns and scalars #2 (element is list) +query ?? +select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, make_array(make_array(1, 2, 3), make_array(11, 12, 13))) from nested_arrays; +---- +[[1, 11, 111], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[7, 8, 9], [1, 2, 3], [11, 12, 13]] +[[1, 11, 111], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [[10, 11, 12], [1, 2, 3], [11, 12, 13]] + ## array_fill # array_fill scalar function #1 @@ -473,19 +527,6 @@ select array_concat(make_array(column2), make_array(column3)) from arrays_values # array_concat column-wise #4 query ? -select array_concat(column1, column2) from arrays_values; ----- -[, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1] -[11, 12, 13, 14, 15, 16, 17, 18, , 20, 12] -[21, 22, 23, , 25, 26, 27, 28, 29, 30, 23] -[31, 32, 33, 34, 35, , 37, 38, 39, 40, 34] -[44] -[41, 42, 43, 44, 45, 46, 47, 48, 49, 50, ] -[51, 52, , 54, 55, 56, 57, 58, 59, 60, 55] -[61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 66] - -# array_concat column-wise #5 -query ? select array_concat(make_array(column2), make_array(0)) from arrays_values; ---- [1, 0] @@ -497,7 +538,7 @@ select array_concat(make_array(column2), make_array(0)) from arrays_values; [55, 0] [66, 0] -# array_concat column-wise #6 +# array_concat column-wise #5 query ??? select array_concat(column1, column1), array_concat(column2, column2), array_concat(column3, column3) from arrays; ---- @@ -509,7 +550,7 @@ NULL [13.3, 14.4, 15.5, 13.3, 14.4, 15.5] [a, m, e, t, a, m, e, t] [[11, 12], [13, 14], [11, 12], [13, 14]] NULL [,, ,] [[15, 16], [, 18], [15, 16], [, 18]] [16.6, 17.7, 18.8, 16.6, 17.7, 18.8] NULL -# array_concat column-wise #7 +# array_concat column-wise #6 query ?? select array_concat(column1, make_array(make_array(1, 2), make_array(3, 4))), array_concat(column2, make_array(1.1, 2.2, 3.3)) from arrays; ---- @@ -521,7 +562,7 @@ select array_concat(column1, make_array(make_array(1, 2), make_array(3, 4))), ar [[11, 12], [13, 14], [1, 2], [3, 4]] [1.1, 2.2, 3.3] [[15, 16], [, 18], [1, 2], [3, 4]] [16.6, 17.7, 18.8, 1.1, 2.2, 3.3] -# array_concat column-wise #8 +# array_concat column-wise #7 query ? select array_concat(column3, make_array('.', '.', '.')) from arrays; ---- @@ -543,7 +584,7 @@ select array_concat(column3, make_array('.', '.', '.')) from arrays; # [11, 12] NULL NULL NULL # NULL NULL NULL NULL -# array_concat column-wise #9 (1D + 1D) +# array_concat column-wise #8 (1D + 1D) query ? select array_concat(column1, column2) from arrays_values_v2; ---- @@ -554,28 +595,36 @@ select array_concat(column1, column2) from arrays_values_v2; [11, 12] NULL -# TODO: Concat columns with different dimensions fails -# array_concat column-wise #10 (1D + 2D) -# query error DataFusion error: Arrow error: Invalid argument error: column types must match schema types, expected List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) but found List\(Field \{ name: "item", data_type: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) at column index 0 -# select array_concat(make_array(column3), column4) from arrays_values_v2; +# array_concat column-wise #9 (2D + 1D) +query ? +select array_concat(column4, make_array(column3)) from arrays_values_v2; +---- +[[30, 40, 50], [12]] +[[, , 60], [13]] +[[70, , ], [14]] +[[]] +[[]] +[[]] + +# array_concat column-wise #10 (3D + 2D + 1D) +query ? +select array_concat(column4, column1, column2) from nested_arrays; +---- +[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]], [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]], [[7, 8, 9]]] +[[[11, 12, 13], [14, 15, 16]], [[17, 18, 19], [20, 21, 22]], [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]], [[10, 11, 12]]] -# array_concat column-wise #11 (1D + Integers) +# array_concat column-wise #11 (2D + 1D) query ? -select array_concat(column2, column3) from arrays_values_v2; +select array_concat(column4, column1) from arrays_values_v2; ---- -[4, 5, , 12] -[7, , 8, 13] -[14] -[, 21, ] +[[30, 40, 50], [, 2, 3]] +[[, , 60], ] +[[70, , ], [9, , 10]] +[[, 1]] +[[11, 12]] [] -[] - -# TODO: Panic at 'range end index 3 out of range for slice of length 2' -# array_concat column-wise #12 (2D + 1D) -# query -# select array_concat(column4, column1) from arrays_values_v2; -# array_concat column-wise #13 (1D + 1D + 1D) +# array_concat column-wise #12 (1D + 1D + 1D) query ? select array_concat(make_array(column3), column1, column2) from arrays_values_v2; ---- @@ -594,13 +643,25 @@ select array_position(['h', 'e', 'l', 'l', 'o'], 'l'), array_position([1, 2, 3, ---- 3 5 1 -# array_position scalar function #2 +# array_position scalar function #2 (with optional argument) query III select array_position(['h', 'e', 'l', 'l', 'o'], 'l', 4), array_position([1, 2, 5, 4, 5], 5, 4), array_position([1, 1, 1], 1, 2); ---- 4 5 2 -# array_position with columns +# array_position scalar function #3 (element is list) +query II +select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6]), array_position(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4]); +---- +2 2 + +# array_position scalar function #4 (element in list; with optional argument) +query II +select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6], 3), array_position(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4], 3); +---- +4 3 + +# array_position with columns #1 query II select array_position(column1, column2), array_position(column1, column2, column3) from arrays_values_without_nulls; ---- @@ -609,24 +670,44 @@ select array_position(column1, column2), array_position(column1, column2, column 3 3 4 4 -# array_position with columns and scalars +# array_position with columns #2 (element is list) query II -select array_position(column1, 3), array_position(column1, 3, 5) from arrays_values_without_nulls; +select array_position(column1, column2), array_position(column1, column2, column3) from nested_arrays; ---- -3 NULL -NULL NULL -NULL NULL -NULL NULL +3 3 +2 5 + +# array_position with columns and scalars #1 +query III +select array_position(make_array(1, 2, 3, 4, 5), column2), array_position(column1, 3), array_position(column1, 3, 5) from arrays_values_without_nulls; +---- +1 3 NULL +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL + +# array_position with columns and scalars #2 (element is list) +query III +select array_position(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), column2), array_position(column1, make_array(4, 5, 6)), array_position(column1, make_array(1, 2, 3), 2) from nested_arrays; +---- +NULL 6 4 +NULL 1 NULL ## array_positions -# array_positions scalar function +# array_positions scalar function #1 query ??? select array_positions(['h', 'e', 'l', 'l', 'o'], 'l'), array_positions([1, 2, 3, 4, 5], 5), array_positions([1, 1, 1], 1); ---- [3, 4] [5] [1, 2, 3] -# array_positions with columns +# array_positions scalar function #2 +query ? +select array_positions(make_array([1, 2, 3], [2, 1, 3], [1, 5, 6], [2, 1, 3], [4, 5, 6]), [2, 1, 3]); +---- +[2, 4] + +# array_positions with columns #1 query ? select array_positions(column1, column2) from arrays_values_without_nulls; ---- @@ -635,7 +716,14 @@ select array_positions(column1, column2) from arrays_values_without_nulls; [3] [4] -# array_positions with columns and scalars +# array_positions with columns #2 (element is list) +query ? +select array_positions(column1, column2) from nested_arrays; +---- +[3] +[2, 5] + +# array_positions with columns and scalars #1 query ?? select array_positions(column1, 4), array_positions(array[1, 2, 23, 13, 33, 45], column2) from arrays_values_without_nulls; ---- @@ -644,6 +732,13 @@ select array_positions(column1, 4), array_positions(array[1, 2, 23, 13, 33, 45], [] [3] [] [] +# array_positions with columns and scalars #2 (element is list) +query ?? +select array_positions(column1, make_array(4, 5, 6)), array_positions(make_array([1, 2, 3], [11, 12, 13], [4, 5, 6]), column2) from nested_arrays; +---- +[6] [] +[1] [] + ## array_replace # array_replace scalar function @@ -1053,6 +1148,9 @@ select make_array(f0) from fixed_size_list_array statement ok drop table values; +statement ok +drop table nested_arrays; + statement ok drop table arrays; diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 104d49e1c876..b16432b50531 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -410,6 +410,7 @@ pub fn array_append(args: &[ArrayRef]) -> Result { let element = &args[1]; let res = match (arr.value_type(), element.data_type()) { + (DataType::List(_), DataType::List(_)) => concat_internal(args)?, (DataType::Utf8, DataType::Utf8) => append!(arr, element, StringArray), (DataType::LargeUtf8, DataType::LargeUtf8) => append!(arr, element, LargeStringArray), (DataType::Boolean, DataType::Boolean) => append!(arr, element, BooleanArray), @@ -499,6 +500,7 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result { let arr = as_list_array(&args[1])?; let res = match (arr.value_type(), element.data_type()) { + (DataType::List(_), DataType::List(_)) => concat_internal(args)?, (DataType::Utf8, DataType::Utf8) => prepend!(arr, element, StringArray), (DataType::LargeUtf8, DataType::LargeUtf8) => prepend!(arr, element, LargeStringArray), (DataType::Boolean, DataType::Boolean) => prepend!(arr, element, BooleanArray), @@ -543,7 +545,18 @@ fn align_array_dimensions(args: Vec) -> Result> { let mut aligned_array = array.clone(); for _ in 0..(max_ndim - ndim) { let data_type = aligned_array.as_ref().data_type().clone(); - aligned_array = array_array(&[aligned_array], data_type)?; + let offsets: Vec = + (0..downcast_arg!(aligned_array, ListArray).offsets().len()) + .map(|i| i as i32) + .collect(); + let field = Arc::new(Field::new("item", data_type, true)); + + aligned_array = Arc::new(ListArray::try_new( + field, + OffsetBuffer::new(offsets.into()), + Arc::new(aligned_array.clone()), + None, + )?) } Ok(aligned_array) } else { @@ -761,6 +774,7 @@ pub fn array_position(args: &[ArrayRef]) -> Result { let res = match arr.data_type() { DataType::List(field) => match field.data_type() { + DataType::List(_) => position!(arr, element, index, ListArray), DataType::Utf8 => position!(arr, element, index, StringArray), DataType::LargeUtf8 => position!(arr, element, index, LargeStringArray), DataType::Boolean => position!(arr, element, index, BooleanArray), @@ -846,6 +860,7 @@ pub fn array_positions(args: &[ArrayRef]) -> Result { let res = match arr.data_type() { DataType::List(field) => match field.data_type() { + DataType::List(_) => positions!(arr, element, ListArray), DataType::Utf8 => positions!(arr, element, StringArray), DataType::LargeUtf8 => positions!(arr, element, LargeStringArray), DataType::Boolean => positions!(arr, element, BooleanArray), @@ -1617,6 +1632,48 @@ mod tests { ); } + #[test] + fn test_nested_array_concat() { + // array_concat([1, 2, 3, 4], [1, 2, 3, 4]) = [1, 2, 3, 4, 1, 2, 3, 4] + let list_array = return_array().into_array(1); + let arr = array_concat(&[list_array.clone(), list_array.clone()]) + .expect("failed to initialize function array_concat"); + let result = + as_list_array(&arr).expect("failed to initialize function array_concat"); + + assert_eq!( + &[1, 2, 3, 4, 1, 2, 3, 4], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + + // array_concat([[1, 2, 3, 4], [5, 6, 7, 8]], [1, 2, 3, 4]) = [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4]] + let list_nested_array = return_nested_array().into_array(1); + let list_array = return_array().into_array(1); + let arr = array_concat(&[list_nested_array, list_array]) + .expect("failed to initialize function array_concat"); + let result = + as_list_array(&arr).expect("failed to initialize function array_concat"); + + assert_eq!( + &[1, 2, 3, 4], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(2) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + } + #[test] fn test_array_fill() { // array_fill(4, [5]) = [4, 4, 4, 4, 4]