diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 5fad7433a57f..3e16b045f465 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -541,7 +541,6 @@ mod test { } #[test] - #[allow(clippy::unnecessary_literal_unwrap)] fn test_make_error_parse_input() { let res: Result<(), DataFusionError> = plan_err!("Err"); let res = res.unwrap_err(); diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 94c98a0cd296..20c90a43e979 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -645,17 +645,17 @@ fn replace_nulls_with_coerced_types( } } +// Coerce array arguments types for array functions +// Convert type or return error for incompatible types in this step fn coerce_array_args( fun: &BuiltinScalarFunction, expressions: Vec, schema: &DFSchema, ) -> Result> { - // Array function with indices don't need coercion for the indices - if *fun == BuiltinScalarFunction::ArrayElement - || *fun == BuiltinScalarFunction::ArraySlice - || *fun == BuiltinScalarFunction::ArrayRepeat - || *fun == BuiltinScalarFunction::ArrayPosition - || *fun == BuiltinScalarFunction::ArrayRemoveN + if *fun != BuiltinScalarFunction::MakeArray + && *fun != BuiltinScalarFunction::ArrayAppend + && *fun != BuiltinScalarFunction::ArrayPrepend + && *fun != BuiltinScalarFunction::ArrayConcat { return Ok(expressions); } @@ -665,6 +665,20 @@ fn coerce_array_args( .map(|e| e.get_type(schema)) .collect::>>()?; + // Check dimensions and align dimensions + // TODO: Move align array dimensions here. Function used in concat, append, prepend. + if *fun == BuiltinScalarFunction::ArrayConcat { + for expr_type in input_types.iter() { + if let DataType::List(_) = expr_type { + continue; + } else { + return plan_err!( + "The array_concat function can only accept list as the args" + ); + } + } + } + // Get base type for each input type // e.g List[Int64] -> Int64 // List[List[Int64]] -> Int64 diff --git a/datafusion/physical-expr/src/window/cume_dist.rs b/datafusion/physical-expr/src/window/cume_dist.rs index edef77c51c31..49ed2a74dff6 100644 --- a/datafusion/physical-expr/src/window/cume_dist.rs +++ b/datafusion/physical-expr/src/window/cume_dist.rs @@ -117,7 +117,6 @@ mod tests { } #[test] - #[allow(clippy::single_range_in_vec_init)] fn test_cume_dist() -> Result<()> { let r = cume_dist("arr".into()); diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index 9bc36728f46e..e24bfee3c060 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -227,7 +227,6 @@ mod tests { test_i32_result(expr, vec![0..2, 2..3, 3..6, 6..7, 7..8], expected) } - #[allow(clippy::single_range_in_vec_init)] fn test_without_rank(expr: &Rank, expected: Vec) -> Result<()> { test_i32_result(expr, vec![0..8], expected) } @@ -276,7 +275,6 @@ mod tests { } #[test] - #[allow(clippy::single_range_in_vec_init)] fn test_percent_rank() -> Result<()> { let r = percent_rank("arr".into()); diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 2254a91c0b78..af696050b1cd 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -166,12 +166,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let data_types = values.iter().map(|e| e.get_datatype()).collect::>(); let seen_types: HashSet = values.iter().map(|e| e.get_datatype()).collect(); - let coerced_type = data_types - .iter() - .skip(1) - .fold(data_types[0].clone(), |acc, d| { - comparison_coercion(&acc, d).unwrap_or(acc) - }); match seen_types.len() { 0 => Ok(lit(ScalarValue::new_list(None, DataType::Utf8))), @@ -180,6 +174,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(lit(ScalarValue::new_list(Some(values), data_type))) } _ => { + let coerced_type = data_types + .iter() + .skip(1) + .fold(data_types[0].clone(), |acc, d| { + comparison_coercion(&acc, d).unwrap_or(acc) + }); let values = values .iter() .map(|e| { diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 4efeff5f2fc6..7968090f15f1 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1241,7 +1241,9 @@ select array_repeat([1], column3), array_repeat(column1, 3) from arrays_values_w ## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`) # array_concat error -query error DataFusion error: Error during planning: The array_concat function can only accept list as the args\. +query error DataFusion error: SQL error: ParserError\("Expected an SQL statement, found: caused"\) +caused by +Error during planning: The array_concat function can only accept list as the args\. select array_concat(1, 2); # array_concat scalar function #1 @@ -1293,21 +1295,16 @@ select array_concat(make_array(make_array(1, 2), make_array(3, 4)), make_array(n ---- [[1, 2], [3, 4], []] -query ? +query error DataFusion error: type_coercion\ncaused by\nError during planning: The array_concat function can only accept list as the args select array_concat(make_array(make_array(1, 2), make_array(3, 4)), null); ----- -[[1, 2], [3, 4]] query ? select array_concat(make_array(make_array(1, 2), make_array(3, 4)), make_array(5)); ---- [[1, 2], [3, 4], [5]] -# TODO: Get error for this query -query ? +query error DataFusion error: type_coercion\ncaused by\nError during planning: The array_concat function can only accept list as the args select array_concat(make_array(make_array(1, 2), make_array(3, 4)), 5); ----- -[[1, 2], [3, 4]] # array_concat scalar function #8 (with empty arrays) query ? @@ -1815,7 +1812,7 @@ select array_replace_all(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12 query TTT select array_to_string(['h', 'e', 'l', 'l', 'o'], ','), array_to_string([1, 2, 3, 4, 5], '-'), array_to_string([1.0, 2.0, 3.0], '|'); ---- -h,e,l,l,o 1-2-3-4-5 1.0|2.0|3.0 +h,e,l,l,o 1-2-3-4-5 1|2|3 # array_to_string scalar function #2 query TTT @@ -1833,31 +1830,31 @@ select array_to_string(make_array(), ',') query TTT select list_to_string(['h', 'e', 'l', 'l', 'o'], ','), list_to_string([1, 2, 3, 4, 5], '-'), list_to_string([1.0, 2.0, 3.0], '|'); ---- -h,e,l,l,o 1-2-3-4-5 1.0|2.0|3.0 +h,e,l,l,o 1-2-3-4-5 1|2|3 # array_join scalar function #5 (function alias `array_to_string`) query TTT select array_join(['h', 'e', 'l', 'l', 'o'], ','), array_join([1, 2, 3, 4, 5], '-'), array_join([1.0, 2.0, 3.0], '|'); ---- -h,e,l,l,o 1-2-3-4-5 1.0|2.0|3.0 +h,e,l,l,o 1-2-3-4-5 1|2|3 # list_join scalar function #6 (function alias `list_join`) query TTT select list_join(['h', 'e', 'l', 'l', 'o'], ','), list_join([1, 2, 3, 4, 5], '-'), list_join([1.0, 2.0, 3.0], '|'); ---- -h,e,l,l,o 1-2-3-4-5 1.0|2.0|3.0 +h,e,l,l,o 1-2-3-4-5 1|2|3 # array_to_string scalar function with nulls #1 query TTT select array_to_string(make_array('h', NULL, 'l', NULL, 'o'), ','), array_to_string(make_array(1, NULL, 3, NULL, 5), '-'), array_to_string(make_array(NULL, 2.0, 3.0), '|'); ---- -h,l,o 1-3-5 2.0|3.0 +h,l,o 1-3-5 2|3 # array_to_string scalar function with nulls #2 query TTT select array_to_string(make_array('h', NULL, NULL, NULL, 'o'), ',', '-'), array_to_string(make_array(NULL, 2, NULL, 4, 5), '-', 'nil'), array_to_string(make_array(1.0, NULL, 3.0), '|', '0'); ---- -h,-,-,-,o nil-2-nil-4-5 1.0|0|3.0 +h,-,-,-,o nil-2-nil-4-5 1|0|3 # array_to_string with columns #1