diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index e9030ebcc00f..cbf5d400bab5 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -599,12 +599,24 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayToString => Ok(Utf8), - BuiltinScalarFunction::ArrayIntersect => Ok(input_expr_types[0].clone()), - BuiltinScalarFunction::ArrayUnion => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArrayUnion | BuiltinScalarFunction::ArrayIntersect => { + match (input_expr_types[0].clone(), input_expr_types[1].clone()) { + (DataType::Null, dt) => Ok(dt), + (dt, DataType::Null) => Ok(dt), + (dt, _) => Ok(dt), + } + } BuiltinScalarFunction::Range => { Ok(List(Arc::new(Field::new("item", Int64, true)))) } - BuiltinScalarFunction::ArrayExcept => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArrayExcept => { + match (input_expr_types[0].clone(), input_expr_types[1].clone()) { + (DataType::Null, _) | (_, DataType::Null) => { + Ok(input_expr_types[0].clone()) + } + (dt, _) => Ok(dt), + } + } BuiltinScalarFunction::Cardinality => Ok(UInt64), BuiltinScalarFunction::MakeArray => match input_expr_types.len() { 0 => Ok(List(Arc::new(Field::new("item", Null, true)))), diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index c0f6c67263a7..8968bcf2ea4e 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -228,10 +228,10 @@ fn compute_array_dims(arr: Option) -> Result>>> fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> { let data_type = args[0].data_type(); - if !args - .iter() - .all(|arg| arg.data_type().equals_datatype(data_type)) - { + if !args.iter().all(|arg| { + arg.data_type().equals_datatype(data_type) + || arg.data_type().equals_datatype(&DataType::Null) + }) { let types = args.iter().map(|arg| arg.data_type()).collect::>(); return plan_err!("{name} received incompatible types: '{types:?}'."); } @@ -1512,19 +1512,29 @@ pub fn array_union(args: &[ArrayRef]) -> Result { match (array1.data_type(), array2.data_type()) { (DataType::Null, _) => Ok(array2.clone()), (_, DataType::Null) => Ok(array1.clone()), - (DataType::List(field_ref), DataType::List(_)) => { - check_datatypes("array_union", &[array1, array2])?; - let list1 = array1.as_list::(); - let list2 = array2.as_list::(); - let result = union_generic_lists::(list1, list2, field_ref)?; - Ok(Arc::new(result)) + (DataType::List(l_field_ref), DataType::List(r_field_ref)) => { + match (l_field_ref.data_type(), r_field_ref.data_type()) { + (DataType::Null, _) => Ok(array2.clone()), + (_, DataType::Null) => Ok(array1.clone()), + (_, _) => { + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = union_generic_lists::(list1, list2, l_field_ref)?; + Ok(Arc::new(result)) + } + } } - (DataType::LargeList(field_ref), DataType::LargeList(_)) => { - check_datatypes("array_union", &[array1, array2])?; - let list1 = array1.as_list::(); - let list2 = array2.as_list::(); - let result = union_generic_lists::(list1, list2, field_ref)?; - Ok(Arc::new(result)) + (DataType::LargeList(l_field_ref), DataType::LargeList(r_field_ref)) => { + match (l_field_ref.data_type(), r_field_ref.data_type()) { + (DataType::Null, _) => Ok(array2.clone()), + (_, DataType::Null) => Ok(array1.clone()), + (_, _) => { + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = union_generic_lists::(list1, list2, l_field_ref)?; + Ok(Arc::new(result)) + } + } } _ => { internal_err!( @@ -1919,55 +1929,66 @@ pub fn string_to_array(args: &[ArrayRef]) -> Result Result { assert_eq!(args.len(), 2); - let first_array = as_list_array(&args[0])?; - let second_array = as_list_array(&args[1])?; + let first_array = &args[0]; + let second_array = &args[1]; - if first_array.value_type() != second_array.value_type() { - return internal_err!("array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'"); - } - let dt = first_array.value_type(); + match (first_array.data_type(), second_array.data_type()) { + (DataType::Null, _) => Ok(second_array.clone()), + (_, DataType::Null) => Ok(first_array.clone()), + _ => { + let first_array = as_list_array(&first_array)?; + let second_array = as_list_array(&second_array)?; - let mut offsets = vec![0]; - let mut new_arrays = vec![]; - - let converter = RowConverter::new(vec![SortField::new(dt.clone())])?; - for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) { - if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) { - let l_values = converter.convert_columns(&[first_arr])?; - let r_values = converter.convert_columns(&[second_arr])?; - - let values_set: HashSet<_> = l_values.iter().collect(); - let mut rows = Vec::with_capacity(r_values.num_rows()); - for r_val in r_values.iter().sorted().dedup() { - if values_set.contains(&r_val) { - rows.push(r_val); - } + if first_array.value_type() != second_array.value_type() { + return internal_err!("array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'"); } - let last_offset: i32 = match offsets.last().copied() { - Some(offset) => offset, - None => return internal_err!("offsets should not be empty"), - }; - offsets.push(last_offset + rows.len() as i32); - let arrays = converter.convert_rows(rows)?; - let array = match arrays.get(0) { - Some(array) => array.clone(), - None => { - return internal_err!( - "array_intersect: failed to get array from rows" - ) + let dt = first_array.value_type(); + + let mut offsets = vec![0]; + let mut new_arrays = vec![]; + + let converter = RowConverter::new(vec![SortField::new(dt.clone())])?; + for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) { + if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) { + let l_values = converter.convert_columns(&[first_arr])?; + let r_values = converter.convert_columns(&[second_arr])?; + + let values_set: HashSet<_> = l_values.iter().collect(); + let mut rows = Vec::with_capacity(r_values.num_rows()); + for r_val in r_values.iter().sorted().dedup() { + if values_set.contains(&r_val) { + rows.push(r_val); + } + } + + let last_offset: i32 = match offsets.last().copied() { + Some(offset) => offset, + None => return internal_err!("offsets should not be empty"), + }; + offsets.push(last_offset + rows.len() as i32); + let arrays = converter.convert_rows(rows)?; + let array = match arrays.get(0) { + Some(array) => array.clone(), + None => { + return internal_err!( + "array_intersect: failed to get array from rows" + ) + } + }; + new_arrays.push(array); } - }; - new_arrays.push(array); + } + + let field = Arc::new(Field::new("item", dt, true)); + let offsets = OffsetBuffer::new(offsets.into()); + let new_arrays_ref = + new_arrays.iter().map(|v| v.as_ref()).collect::>(); + let values = compute::concat(&new_arrays_ref)?; + let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?); + Ok(arr) } } - - let field = Arc::new(Field::new("item", dt, true)); - let offsets = OffsetBuffer::new(offsets.into()); - let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); - let values = compute::concat(&new_arrays_ref)?; - let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?); - Ok(arr) } #[cfg(test)] diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 3a06fdb158f7..0f086bca6819 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -16,20 +16,20 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use arrow::array::new_null_array; use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano; use arrow::datatypes::DECIMAL128_MAX_PRECISION; use arrow_schema::DataType; use datafusion_common::{ not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr::{BinaryExpr, Placeholder}; +use datafusion_expr::BuiltinScalarFunction; use datafusion_expr::{lit, Expr, Operator}; use log::debug; use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value}; use sqlparser::parser::ParserError::ParserError; use std::borrow::Cow; -use std::collections::HashSet; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn parse_value( @@ -138,9 +138,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, &mut PlannerContext::new(), )?; + match value { - Expr::Literal(scalar) => { - values.push(scalar); + Expr::Literal(_) => { + values.push(value); + } + Expr::ScalarFunction(ref scalar_function) => { + if scalar_function.fun == BuiltinScalarFunction::MakeArray { + values.push(value); + } else { + return not_impl_err!( + "ScalarFunctions without MakeArray are not supported: {value}" + ); + } } _ => { return not_impl_err!( @@ -150,18 +160,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - let data_types: HashSet = - values.iter().map(|e| e.data_type()).collect(); - - if data_types.is_empty() { - Ok(lit(ScalarValue::List(new_null_array(&DataType::Null, 0)))) - } else if data_types.len() > 1 { - not_impl_err!("Arrays with different types are not supported: {data_types:?}") - } else { - let data_type = values[0].data_type(); - let arr = ScalarValue::new_list(&values, &data_type); - Ok(lit(ScalarValue::List(arr))) - } + Ok(Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::MakeArray, + values, + ))) } /// Convert a SQL interval expression to a DataFusion logical plan diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 4c2bad1c719e..a56e9a50f054 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -1383,18 +1383,6 @@ fn select_interval_out_of_range() { ); } -#[test] -fn select_array_no_common_type() { - let sql = "SELECT [1, true, null]"; - let err = logical_plan(sql).expect_err("query should have failed"); - - // HashSet doesn't guarantee order - assert_contains!( - err.strip_backtrace(), - "This feature is not implemented: Arrays with different types are not supported: " - ); -} - #[test] fn recursive_ctes() { let sql = " @@ -1411,16 +1399,6 @@ fn recursive_ctes() { ); } -#[test] -fn select_array_non_literal_type() { - let sql = "SELECT [now()]"; - let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "This feature is not implemented: Arrays with elements other than literal are not supported: now()", - err.strip_backtrace() - ); -} - #[test] fn select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() { quick_test( diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index faad6feb3f33..7157be948914 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1396,7 +1396,7 @@ SELECT COUNT(DISTINCT c1) FROM test query ? SELECT ARRAY_AGG([]) ---- -[] +[[]] # array_agg_one query ? @@ -1419,7 +1419,7 @@ e 4 query ? SELECT ARRAY_AGG([]); ---- -[] +[[]] # array_agg_one query ? diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 61f190e7baf6..d33555509e6c 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -265,6 +265,14 @@ AS VALUES (make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), [28, 29, 30], [37, 38, 39], 10) ; +query ? +select [1, true, null] +---- +[1, 1, ] + +query error DataFusion error: This feature is not implemented: ScalarFunctions without MakeArray are not supported: now() +SELECT [now()] + query TTT select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3) from arrays; ---- @@ -2014,7 +2022,7 @@ drop table arrays_with_repeating_elements_for_union; query ? select array_union([], []); ---- -NULL +[] # array_union scalar function #7 query ? @@ -2032,7 +2040,7 @@ select array_union([null], [null]); query ? select array_union(null, []); ---- -NULL +[] # array_union scalar function #10 query ? @@ -2687,6 +2695,26 @@ SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)), ---- [2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] +query ? +select array_intersect([], []); +---- +[] + +query ? +select array_intersect([], null); +---- +[] + +query ? +select array_intersect(null, []); +---- +[] + +query ? +select array_intersect(null, null); +---- +NULL + query ?????? SELECT list_intersect(make_array(1,2,3), make_array(2,3,4)), list_intersect(make_array(1,3,5), make_array(2,4,6)), @@ -2842,6 +2870,26 @@ NULL statement ok drop table array_except_table_bool; +query ? +select array_except([], null); +---- +[] + +query ? +select array_except([], []); +---- +[] + +query ? +select array_except(null, []); +---- +NULL + +query ? +select array_except(null, null) +---- +NULL + ### Array operators tests