From 8012c4de3a7d20adb5af5298eb1dd2103c98a481 Mon Sep 17 00:00:00 2001
From: Jay Zhan <jayzhan211@gmail.com>
Date: Fri, 14 Jul 2023 19:10:10 +0800
Subject: [PATCH] Column support for array concat (#6879)

* first draft

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* use old concat func

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* merge main and add tests

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* support nulls

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* add tests

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* cleanup

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* add more failed tests

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* update tests

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

---------

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
---
 .../tests/sqllogictests/test_files/array.slt  | 164 ++++++++++++++++++
 .../physical-expr/src/array_expressions.rs    |  80 ++++++---
 2 files changed, 216 insertions(+), 28 deletions(-)

diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt
index cf20c14cac76..d9b3449dfe6e 100644
--- a/datafusion/core/tests/sqllogictests/test_files/array.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/array.slt
@@ -68,6 +68,17 @@ AS VALUES
   (make_array(61, 62, 63, 64, 65, 66, 67, 68, 69, 70), 66, 7, NULL)
 ;
 
+statement ok
+CREATE TABLE arrays_values_v2
+AS VALUES
+  (make_array(NULL, 2, 3), make_array(4, 5, NULL), 12, make_array([30, 40, 50])),
+  (NULL, make_array(7, NULL, 8), 13, make_array(make_array(NULL,NULL,60))),
+  (make_array(9, NULL, 10), NULL, 14, make_array(make_array(70,NULL,NULL))),
+  (make_array(NULL, 1), make_array(NULL, 21), NULL, NULL),
+  (make_array(11, 12), NULL, NULL, NULL),
+  (NULL, NULL, NULL, NULL)
+;
+
 statement ok
 CREATE TABLE arrays_values_without_nulls
 AS VALUES
@@ -116,6 +127,16 @@ NULL 44 5 @
 [51, 52, , 54, 55, 56, 57, 58, 59, 60] 55 NULL ^
 [61, 62, 63, 64, 65, 66, 67, 68, 69, 70] 66 7 NULL
 
+query ??I?
+select column1, column2, column3, column4 from arrays_values_v2;
+----
+[, 2, 3] [4, 5, ] 12 [[30, 40, 50]]
+NULL [7, , 8] 13 [[, , 60]]
+[9, , 10] NULL 14 [[70, , ]]
+[, 1] [, 21] NULL NULL
+[11, 12] NULL NULL NULL
+NULL NULL NULL NULL
+
 # arrays_values_without_nulls table
 query ?II
 select column1, column2, column3 from arrays_values_without_nulls;
@@ -423,6 +444,148 @@ select array_concat(make_array(10, 20), make_array([30, 40]), make_array([[50, 6
 ----
 [[[10, 20]], [[30, 40]], [[50, 60]]]
 
+# array_concat column-wise #1
+query ?
+select array_concat(column1, make_array(0)) from arrays_values_without_nulls;
+----
+[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0]
+[11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 0]
+[21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 0]
+[31, 32, 33, 34, 35, 26, 37, 38, 39, 40, 0]
+
+# array_concat column-wise #2
+query ?
+select array_concat(column1, column1) from arrays_values_without_nulls;
+----
+[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+[11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
+[21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]
+[31, 32, 33, 34, 35, 26, 37, 38, 39, 40, 31, 32, 33, 34, 35, 26, 37, 38, 39, 40]
+
+# array_concat column-wise #3
+query ?
+select array_concat(make_array(column2), make_array(column3)) from arrays_values_without_nulls;
+----
+[1, 1]
+[12, 2]
+[23, 3]
+[34, 4]
+
+# array_concat column-wise #4
+query ?
+select array_concat(column1, column2) from arrays_values;
+----
+[, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1]
+[11, 12, 13, 14, 15, 16, 17, 18, , 20, 12]
+[21, 22, 23, , 25, 26, 27, 28, 29, 30, 23]
+[31, 32, 33, 34, 35, , 37, 38, 39, 40, 34]
+[44]
+[41, 42, 43, 44, 45, 46, 47, 48, 49, 50, ]
+[51, 52, , 54, 55, 56, 57, 58, 59, 60, 55]
+[61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 66]
+
+# array_concat column-wise #5
+query ?
+select array_concat(make_array(column2), make_array(0)) from arrays_values;
+----
+[1, 0]
+[12, 0]
+[23, 0]
+[34, 0]
+[44, 0]
+[, 0]
+[55, 0]
+[66, 0]
+
+# array_concat column-wise #6
+query ???
+select array_concat(column1, column1), array_concat(column2, column2), array_concat(column3, column3) from arrays;
+----
+[[, 2], [3, ], [, 2], [3, ]] [1.1, 2.2, 3.3, 1.1, 2.2, 3.3] [L, o, r, e, m, L, o, r, e, m]
+[[3, 4], [5, 6], [3, 4], [5, 6]] [, 5.5, 6.6, , 5.5, 6.6] [i, p, , u, m, i, p, , u, m]
+[[5, 6], [7, 8], [5, 6], [7, 8]] [7.7, 8.8, 9.9, 7.7, 8.8, 9.9] [d, , l, o, r, d, , l, o, r]
+[[7, ], [9, 10], [7, ], [9, 10]] [10.1, , 12.2, 10.1, , 12.2] [s, i, t, s, i, t]
+NULL [13.3, 14.4, 15.5, 13.3, 14.4, 15.5] [a, m, e, t, a, m, e, t]
+[[11, 12], [13, 14], [11, 12], [13, 14]] NULL [,, ,]
+[[15, 16], [, 18], [15, 16], [, 18]] [16.6, 17.7, 18.8, 16.6, 17.7, 18.8] NULL
+
+# array_concat column-wise #7
+query ??
+select array_concat(column1, make_array(make_array(1, 2), make_array(3, 4))), array_concat(column2, make_array(1.1, 2.2, 3.3)) from arrays;
+----
+[[, 2], [3, ], [1, 2], [3, 4]] [1.1, 2.2, 3.3, 1.1, 2.2, 3.3]
+[[3, 4], [5, 6], [1, 2], [3, 4]] [, 5.5, 6.6, 1.1, 2.2, 3.3]
+[[5, 6], [7, 8], [1, 2], [3, 4]] [7.7, 8.8, 9.9, 1.1, 2.2, 3.3]
+[[7, ], [9, 10], [1, 2], [3, 4]] [10.1, , 12.2, 1.1, 2.2, 3.3]
+[[1, 2], [3, 4]] [13.3, 14.4, 15.5, 1.1, 2.2, 3.3]
+[[11, 12], [13, 14], [1, 2], [3, 4]] [1.1, 2.2, 3.3]
+[[15, 16], [, 18], [1, 2], [3, 4]] [16.6, 17.7, 18.8, 1.1, 2.2, 3.3]
+
+# array_concat column-wise #8
+query ?
+select array_concat(column3, make_array('.', '.', '.')) from arrays;
+----
+[L, o, r, e, m, ., ., .]
+[i, p, , u, m, ., ., .]
+[d, , l, o, r, ., ., .]
+[s, i, t, ., ., .]
+[a, m, e, t, ., ., .]
+[,, ., ., .]
+[., ., .]
+
+# query ??I?
+# select column1, column2, column3, column4 from arrays_values_v2;
+# ----
+# [, 2, 3] [4, 5, ] 12 [[30, 40, 50]]
+# NULL [7, , 8] 13 [[, , 60]]
+# [9, , 10] NULL 14 [[70, , ]]
+# [, 1] [, 21] NULL NULL
+# [11, 12] NULL NULL NULL
+# NULL NULL NULL NULL
+
+# array_concat column-wise #9 (1D + 1D)
+query ?
+select array_concat(column1, column2) from arrays_values_v2;
+----
+[, 2, 3, 4, 5, ]
+[7, , 8]
+[9, , 10]
+[, 1, , 21]
+[11, 12]
+NULL
+
+# TODO: Concatenating columns with different dimensions fails
+# array_concat column-wise #10 (1D + 2D)
+# query error DataFusion error: Arrow error: Invalid argument error: column types must match schema types, expected List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) but found List\(Field \{ name: "item", data_type: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) at column index 0
+# select array_concat(make_array(column3), column4) from arrays_values_v2;
+
+# array_concat column-wise #11 (1D + Integers)
+query ?
+select array_concat(column2, column3) from arrays_values_v2;
+----
+[4, 5, , 12]
+[7, , 8, 13]
+[14]
+[, 21, ]
+[]
+[]
+
+# TODO: Panics with 'range end index 3 out of range for slice of length 2'
+# array_concat column-wise #12 (2D + 1D)
+# query
+# select array_concat(column4, column1) from arrays_values_v2;
+
+# array_concat column-wise #13 (1D + 1D + 1D)
+query ?
+select array_concat(make_array(column3), column1, column2) from arrays_values_v2;
+----
+[12, , 2, 3, 4, 5, ]
+[13, 7, , 8]
+[14, 9, , 10]
+[, , 1, , 21]
+[, 11, 12]
+[]
+
 ## array_position
 
 # array_position scalar function #1
@@ -835,6 +998,7 @@ select make_array(f0) from fixed_size_list_array
 [[1, 2], [3, 4]]
 
 
+
 ### Delete tables
 
 
diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs
index a0fed8f90889..93c1626daffd 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -21,6 +21,7 @@ use arrow::array::*;
 use arrow::buffer::{Buffer, OffsetBuffer};
 use arrow::compute;
 use arrow::datatypes::{DataType, Field, UInt64Type};
+use arrow_buffer::NullBuffer;
 use core::any::type_name;
 use datafusion_common::cast::{as_generic_string_array, as_int64_array, as_list_array};
 use datafusion_common::ScalarValue;
@@ -554,42 +555,65 @@ fn align_array_dimensions(args: Vec<ArrayRef>) -> Result<Vec<ArrayRef>> {
     aligned_args
 }
 
-/// Array_concat/Array_cat SQL function
-pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
-    match args[0].data_type() {
-        DataType::List(field) => match field.data_type() {
-            DataType::Null => array_concat(&args[1..]),
-            _ => {
-                let args = align_array_dimensions(args.to_vec())?;
+fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
+    let args = align_array_dimensions(args.to_vec())?;
 
-                let list_arrays = downcast_vec!(args, ListArray)
-                    .collect::<Result<Vec<&ListArray>>>()?;
+    let list_arrays =
+        downcast_vec!(args, ListArray).collect::<Result<Vec<&ListArray>>>()?;
 
-                let len: usize = list_arrays.iter().map(|a| a.values().len()).sum();
+    // Assumes every input array has the same number of rows as the first one
+    let row_count = list_arrays[0].len();
+    let capacity = Capacities::Array(list_arrays.iter().map(|a| a.len()).sum());
+    let array_data: Vec<_> = list_arrays.iter().map(|a| a.to_data()).collect::<Vec<_>>();
+    let array_data: Vec<&ArrayData> = array_data.iter().collect();
 
-                let capacity =
-                    Capacities::Array(list_arrays.iter().map(|a| a.len()).sum());
-                let array_data: Vec<_> =
-                    list_arrays.iter().map(|a| a.to_data()).collect::<Vec<_>>();
+    let mut mutable = MutableArrayData::with_capacities(array_data, true, capacity);
 
-                let array_data = array_data.iter().collect();
+    let mut array_lens = vec![0; row_count];
+    let mut null_bit_map: Vec<bool> = vec![true; row_count];
 
-                let mut mutable =
-                    MutableArrayData::with_capacities(array_data, false, capacity);
+    for (i, array_len) in array_lens.iter_mut().enumerate().take(row_count) {
+        let null_count = mutable.null_count();
+        for (j, a) in list_arrays.iter().enumerate() {
+            mutable.extend(j, i, i + 1);
+            *array_len += a.value_length(i);
+        }
 
-                for (i, a) in list_arrays.iter().enumerate() {
-                    mutable.extend(i, 0, a.len())
-                }
+        // Every input list is null at row `i`, so the output row is null as well
+        if mutable.null_count() == null_count + list_arrays.len() {
+            null_bit_map[i] = false;
+        }
+    }
 
-                let builder = mutable.into_builder();
-                let list = builder
-                    .len(1)
-                    .buffers(vec![Buffer::from_slice_ref([0, len as i32])])
-                    .build()
-                    .unwrap();
+    let mut buffer = BooleanBufferBuilder::new(row_count);
+    buffer.append_slice(null_bit_map.as_slice());
+    let nulls = Some(NullBuffer::from(buffer.finish()));
 
-                return Ok(Arc::new(arrow::array::make_array(list)));
-            }
+    let offsets: Vec<i32> = std::iter::once(0)
+        .chain(array_lens.iter().scan(0, |state, &x| {
+            *state += x;
+            Some(*state)
+        }))
+        .collect();
+
+    let builder = mutable.into_builder();
+
+    let list = builder
+        .len(row_count)
+        .buffers(vec![Buffer::from_vec(offsets)])
+        .nulls(nulls)
+        .build()?;
+
+    let list = arrow::array::make_array(list);
+    Ok(Arc::new(list))
+}
+
+/// Array_concat/Array_cat SQL function
+pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
+    match args[0].data_type() {
+        DataType::List(field) => match field.data_type() {
+            DataType::Null => array_concat(&args[1..]),
+            _ => concat_internal(args),
         },
         data_type => Err(DataFusionError::NotImplemented(format!(
             "Array is not type '{data_type:?}'."