From a76beaab5f134908db328a137967d7fc750cfed8 Mon Sep 17 00:00:00 2001 From: veeupup Date: Mon, 20 Nov 2023 23:02:04 +0800 Subject: [PATCH] fix sql_array_literal Signed-off-by: veeupup --- datafusion/expr/src/built_in_function.rs | 21 ++--- .../physical-expr/src/array_expressions.rs | 89 ++++++------------- datafusion/sql/src/expr/value.rs | 32 +++---- 3 files changed, 51 insertions(+), 91 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index b7cee92d9ac8b..cbf5d400bab58 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -599,19 +599,10 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayToString => Ok(Utf8), - BuiltinScalarFunction::ArrayIntersect => { + BuiltinScalarFunction::ArrayUnion | BuiltinScalarFunction::ArrayIntersect => { match (input_expr_types[0].clone(), input_expr_types[1].clone()) { - (DataType::Null, DataType::Null) => Ok(DataType::List(Arc::new( - Field::new("item", DataType::Null, true), - ))), - (dt, _) => Ok(dt), - } - } - BuiltinScalarFunction::ArrayUnion => { - match (input_expr_types[0].clone(), input_expr_types[1].clone()) { - (DataType::Null, DataType::Null) => Ok(DataType::List(Arc::new( - Field::new("item", DataType::Null, true), - ))), + (DataType::Null, dt) => Ok(dt), + (dt, DataType::Null) => Ok(dt), (dt, _) => Ok(dt), } } @@ -620,9 +611,9 @@ impl BuiltinScalarFunction { } BuiltinScalarFunction::ArrayExcept => { match (input_expr_types[0].clone(), input_expr_types[1].clone()) { - (DataType::Null, DataType::Null) => Ok(DataType::List(Arc::new( - Field::new("item", DataType::Null, true), - ))), + (DataType::Null, _) | (_, DataType::Null) => { + Ok(input_expr_types[0].clone()) + } (dt, _) => Ok(dt), } } diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 97f1548888884..2debe73c6f2f2 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -228,10 +228,10 @@ fn compute_array_dims(arr: Option) -> Result>>> fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> { let data_type = args[0].data_type(); - if !args - .iter() - .all(|arg| arg.data_type().equals_datatype(data_type)) - { + if !args.iter().all(|arg| { + arg.data_type().equals_datatype(data_type) + || arg.data_type().equals_datatype(&DataType::Null) + }) { let types = args.iter().map(|arg| arg.data_type()).collect::>(); return plan_err!("{name} received incompatible types: '{types:?}'."); } @@ -580,21 +580,6 @@ pub fn array_except(args: &[ArrayRef]) -> Result { let array2 = &args[1]; match (array1.data_type(), array2.data_type()) { - (DataType::Null, DataType::Null) => { - // NullArray(1): means null, NullArray(0): means [] - // except([], []) = [], except([], null) = [], except(null, []) = null, except(null, null) = null - let nulls = match (array1.len(), array2.len()) { - (1, _) => Some(NullBuffer::new_null(1)), - _ => None, - }; - let arr = Arc::new(ListArray::try_new( - Arc::new(Field::new("item", DataType::Null, true)), - OffsetBuffer::new(vec![0; 2].into()), - Arc::new(NullArray::new(0)), - nulls, - )?) as ArrayRef; - Ok(arr) - } (DataType::Null, _) | (_, DataType::Null) => Ok(array1.to_owned()), (DataType::List(field), DataType::List(_)) => { check_datatypes("array_except", &[array1, array2])?; @@ -1525,36 +1510,31 @@ pub fn array_union(args: &[ArrayRef]) -> Result { let array1 = &args[0]; let array2 = &args[1]; match (array1.data_type(), array2.data_type()) { - (DataType::Null, DataType::Null) => { - // NullArray(1): means null, NullArray(0): means [] - // union([], []) = [], union([], null) = [], union(null, []) = [], union(null, null) = null - let nulls = match (array1.len(), array2.len()) { - (1, 1) => Some(NullBuffer::new_null(1)), - _ => None, - }; - let arr = Arc::new(ListArray::try_new( - Arc::new(Field::new("item", DataType::Null, true)), - OffsetBuffer::new(vec![0; 2].into()), - Arc::new(NullArray::new(0)), - nulls, - )?) as ArrayRef; - Ok(arr) - } (DataType::Null, _) => Ok(array2.clone()), (_, DataType::Null) => Ok(array1.clone()), - (DataType::List(field_ref), DataType::List(_)) => { - check_datatypes("array_union", &[array1, array2])?; - let list1 = array1.as_list::(); - let list2 = array2.as_list::(); - let result = union_generic_lists::(list1, list2, field_ref)?; - Ok(Arc::new(result)) + (DataType::List(l_field_ref), DataType::List(r_field_ref)) => { + match (l_field_ref.data_type(), r_field_ref.data_type()) { + (DataType::Null, _) => Ok(array2.clone()), + (_, DataType::Null) => Ok(array1.clone()), + (_, _) => { + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = union_generic_lists::(list1, list2, &l_field_ref)?; + Ok(Arc::new(result)) + } + } } - (DataType::LargeList(field_ref), DataType::LargeList(_)) => { - check_datatypes("array_union", &[array1, array2])?; - let list1 = array1.as_list::(); - let list2 = array2.as_list::(); - let result = union_generic_lists::(list1, list2, field_ref)?; - Ok(Arc::new(result)) + (DataType::LargeList(l_field_ref), DataType::LargeList(r_field_ref)) => { + match (l_field_ref.data_type(), r_field_ref.data_type()) { + (DataType::Null, _) => Ok(array2.clone()), + (_, DataType::Null) => Ok(array1.clone()), + (_, _) => { + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = union_generic_lists::(list1, list2, &l_field_ref)?; + Ok(Arc::new(result)) + } + } } _ => { internal_err!( @@ -2032,21 +2012,8 @@ pub fn array_intersect(args: &[ArrayRef]) -> Result { let second_array = &args[1]; match (first_array.data_type(), second_array.data_type()) { - (DataType::Null, DataType::Null) => { - // NullArray(1): means null, NullArray(0): means [] - // intersect([], []) = [], intersect([], null) = [], intersect(null, []) = [], intersect(null, null) = null - let nulls = match (first_array.len(), second_array.len()) { - (1, 1) => Some(NullBuffer::new_null(1)), - _ => None, - }; - let arr = Arc::new(ListArray::try_new( - Arc::new(Field::new("item", DataType::Null, true)), - OffsetBuffer::new(vec![0; 2].into()), - Arc::new(NullArray::new(0)), - nulls, - )?) as ArrayRef; - Ok(arr) - } + (DataType::Null, _) => Ok(second_array.clone()), + (_, DataType::Null) => Ok(first_array.clone()), _ => { let first_array = as_list_array(&first_array)?; let second_array = as_list_array(&second_array)?; diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 3a06fdb158f76..aeee6ee55326f 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -16,20 +16,20 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use arrow::array::new_null_array; use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano; use arrow::datatypes::DECIMAL128_MAX_PRECISION; use arrow_schema::DataType; use datafusion_common::{ not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr::{BinaryExpr, Placeholder}; +use datafusion_expr::BuiltinScalarFunction; use datafusion_expr::{lit, Expr, Operator}; use log::debug; use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value}; use sqlparser::parser::ParserError::ParserError; use std::borrow::Cow; -use std::collections::HashSet; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn parse_value( @@ -138,9 +138,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, &mut PlannerContext::new(), )?; + match value { Expr::Literal(scalar) => { - values.push(scalar); + values.push(Expr::Literal(scalar)); + } + Expr::ScalarFunction(ref scalar_function) => { + if scalar_function.fun == BuiltinScalarFunction::MakeArray { + values.push(Expr::ScalarFunction(scalar_function.clone())); + } else { + return not_impl_err!( + "Arrays with elements other than literal are not supported: {value}" + ); + } } _ => { return not_impl_err!( @@ -150,18 +160,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - let data_types: HashSet = - values.iter().map(|e| e.data_type()).collect(); - - if data_types.is_empty() { - Ok(lit(ScalarValue::List(new_null_array(&DataType::Null, 0)))) - } else if data_types.len() > 1 { - not_impl_err!("Arrays with different types are not supported: {data_types:?}") - } else { - let data_type = values[0].data_type(); - let arr = ScalarValue::new_list(&values, &data_type); - Ok(lit(ScalarValue::List(arr))) - } + Ok(Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::MakeArray, + values, + ))) } /// Convert a SQL interval expression to a DataFusion logical plan