diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 6fa41495f0f70..454401b4715a4 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -962,7 +962,7 @@ impl BuiltinScalarFunction { Signature::any(2, self.volatility()) } BuiltinScalarFunction::ArrayHas => { - Signature::array_and_element(false, self.volatility()) + Signature::array_and_element(true, self.volatility()) } BuiltinScalarFunction::ArrayLength => { Signature::variadic_any(self.volatility()) @@ -977,18 +977,18 @@ impl BuiltinScalarFunction { Signature::variadic_any(self.volatility()) } BuiltinScalarFunction::ArrayPositions => { - Signature::array_and_element(false, self.volatility()) + Signature::array_and_element(true, self.volatility()) } BuiltinScalarFunction::ArrayPrepend => { Signature::element_and_array(false, self.volatility()) } BuiltinScalarFunction::ArrayRepeat => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayRemove => { - Signature::array_and_element(false, self.volatility()) + Signature::array_and_element(true, self.volatility()) } BuiltinScalarFunction::ArrayRemoveN => Signature::any(3, self.volatility()), BuiltinScalarFunction::ArrayRemoveAll => { - Signature::array_and_element(false, self.volatility()) + Signature::array_and_element(true, self.volatility()) } BuiltinScalarFunction::ArrayReplace => Signature::any(3, self.volatility()), BuiltinScalarFunction::ArrayReplaceN => Signature::any(4, self.volatility()), diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index 2da608e5771b2..74768b4992bb5 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -173,8 +173,12 @@ impl ArrayFunctionSignature { }; // We follow Postgres on `array_append(Null, T)`, which is not valid. - if array_type.eq(&DataType::Null) && !allow_null_coercion { - return Ok(vec![vec![]]); + if array_type.eq(&DataType::Null) { + if allow_null_coercion { + return Ok(vec![vec![array_type.clone(), elem_type.clone()]]); + } else { + return Ok(vec![vec![]]); + } } // We need to find the coerced base type, mainly for cases like: @@ -189,20 +193,21 @@ impl ArrayFunctionSignature { ) })?; - let array_type = datafusion_common::utils::coerced_type_with_base_type_only( - array_type, - &new_base_type, - ); + let new_array_type = + datafusion_common::utils::coerced_type_with_base_type_only( + array_type, + &new_base_type, + ); - match array_type { + match new_array_type { DataType::List(ref field) | DataType::LargeList(ref field) | DataType::FixedSizeList(ref field, _) => { - let elem_type = field.data_type(); + let new_elem_type = field.data_type(); if is_append { - Ok(vec![vec![array_type.clone(), elem_type.clone()]]) + Ok(vec![vec![new_array_type.clone(), new_elem_type.clone()]]) } else { - Ok(vec![vec![elem_type.to_owned(), array_type.clone()]]) + Ok(vec![vec![new_elem_type.to_owned(), new_array_type.clone()]]) } } _ => Ok(vec![vec![]]), diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 50c70ccfdb111..a49e1ef0a29a4 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -86,7 +86,9 @@ fn compare_element_to_list( row_index: usize, eq: bool, ) -> Result { - if list_array_row.data_type() != element_array.data_type() { + if list_array_row.data_type() != element_array.data_type() + && !element_array.data_type().is_null() + { return exec_err!( "compare_element_to_list received incompatible types: '{:?}' and '{:?}'.", list_array_row.data_type(), @@ -1481,6 +1483,10 @@ pub fn array_positions(args: &[ArrayRef]) -> Result { check_datatypes("array_positions", &[arr.values(), element])?; general_positions::(arr, element) } + DataType::Null => Ok(new_null_array( + &DataType::List(Arc::new(Field::new("item", DataType::UInt64, true))), + 1, + )), array_type => { exec_err!("array_positions does not support type '{array_type:?}'.") } @@ -1613,6 +1619,10 @@ fn array_remove_internal( element_array: &ArrayRef, arr_n: Vec, ) -> Result { + if array.data_type().is_null() { + return Ok(array.clone()); + } + match array.data_type() { DataType::List(_) => { let list_array = array.as_list::(); @@ -2288,6 +2298,7 @@ pub fn array_has(args: &[ArrayRef]) -> Result { DataType::LargeList(_) => { general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Single) } + DataType::Null => Ok(new_null_array(&DataType::Boolean, 1)), _ => exec_err!("array_has does not support type '{array_type:?}'."), } } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 789238b5723b4..a91c87c24404a 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2745,12 +2745,17 @@ NULL 1 NULL ## array_positions (aliases: `list_positions`) -# array_position with NULL (follow PostgreSQL) query ? select array_positions([1, 2, 3, 4, 5], null); ---- [] +# array_positions with NULL (follow PostgreSQL) +query ? +select array_positions(null, 1); +---- +NULL + # array_positions scalar function #1 query ??? select array_positions(['h', 'e', 'l', 'l', 'o'], 'l'), array_positions([1, 2, 3, 4, 5], 5), array_positions([1, 1, 1], 1); @@ -3874,6 +3879,13 @@ select ---- [1, , 3] [, 2.2, 3.3] [, bc] +# follow PostgreSQL behavior +query ? +select + array_remove(NULL, 1) +---- +NULL + query ?? select array_remove(make_array(1, null, 2), null), @@ -4034,6 +4046,11 @@ select array_remove_n(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], ## array_remove_all (aliases: `list_removes`) # array_remove_all with NULL elements +query ? +select array_remove_all(NULL, 1); +---- +NULL + query ? select array_remove_all(make_array(1, 2, 2, 1, 1), NULL); ----