Skip to content

Commit

Permalink
fix sql_array_literal
Browse files Browse the repository at this point in the history
Signed-off-by: veeupup <[email protected]>
  • Loading branch information
Veeupup committed Nov 20, 2023
1 parent c08d6cb commit a76beaa
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 91 deletions.
21 changes: 6 additions & 15 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -599,19 +599,10 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayToString => Ok(Utf8),
BuiltinScalarFunction::ArrayIntersect => {
BuiltinScalarFunction::ArrayUnion | BuiltinScalarFunction::ArrayIntersect => {
match (input_expr_types[0].clone(), input_expr_types[1].clone()) {
(DataType::Null, DataType::Null) => Ok(DataType::List(Arc::new(
Field::new("item", DataType::Null, true),
))),
(dt, _) => Ok(dt),
}
}
BuiltinScalarFunction::ArrayUnion => {
match (input_expr_types[0].clone(), input_expr_types[1].clone()) {
(DataType::Null, DataType::Null) => Ok(DataType::List(Arc::new(
Field::new("item", DataType::Null, true),
))),
(DataType::Null, dt) => Ok(dt),
(dt, DataType::Null) => Ok(dt),
(dt, _) => Ok(dt),
}
}
Expand All @@ -620,9 +611,9 @@ impl BuiltinScalarFunction {
}
BuiltinScalarFunction::ArrayExcept => {
match (input_expr_types[0].clone(), input_expr_types[1].clone()) {
(DataType::Null, DataType::Null) => Ok(DataType::List(Arc::new(
Field::new("item", DataType::Null, true),
))),
(DataType::Null, _) | (_, DataType::Null) => {
Ok(input_expr_types[0].clone())
}
(dt, _) => Ok(dt),
}
}
Expand Down
89 changes: 28 additions & 61 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,10 @@ fn compute_array_dims(arr: Option<ArrayRef>) -> Result<Option<Vec<Option<u64>>>>

fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> {
let data_type = args[0].data_type();
if !args
.iter()
.all(|arg| arg.data_type().equals_datatype(data_type))
{
if !args.iter().all(|arg| {
arg.data_type().equals_datatype(data_type)
|| arg.data_type().equals_datatype(&DataType::Null)
}) {
let types = args.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
return plan_err!("{name} received incompatible types: '{types:?}'.");
}
Expand Down Expand Up @@ -580,21 +580,6 @@ pub fn array_except(args: &[ArrayRef]) -> Result<ArrayRef> {
let array2 = &args[1];

match (array1.data_type(), array2.data_type()) {
(DataType::Null, DataType::Null) => {
// NullArray(1): means null, NullArray(0): means []
// except([], []) = [], except([], null) = [], except(null, []) = null, except(null, null) = null
let nulls = match (array1.len(), array2.len()) {
(1, _) => Some(NullBuffer::new_null(1)),
_ => None,
};
let arr = Arc::new(ListArray::try_new(
Arc::new(Field::new("item", DataType::Null, true)),
OffsetBuffer::new(vec![0; 2].into()),
Arc::new(NullArray::new(0)),
nulls,
)?) as ArrayRef;
Ok(arr)
}
(DataType::Null, _) | (_, DataType::Null) => Ok(array1.to_owned()),
(DataType::List(field), DataType::List(_)) => {
check_datatypes("array_except", &[array1, array2])?;
Expand Down Expand Up @@ -1525,36 +1510,31 @@ pub fn array_union(args: &[ArrayRef]) -> Result<ArrayRef> {
let array1 = &args[0];
let array2 = &args[1];
match (array1.data_type(), array2.data_type()) {
(DataType::Null, DataType::Null) => {
// NullArray(1): means null, NullArray(0): means []
// union([], []) = [], union([], null) = [], union(null, []) = [], union(null, null) = null
let nulls = match (array1.len(), array2.len()) {
(1, 1) => Some(NullBuffer::new_null(1)),
_ => None,
};
let arr = Arc::new(ListArray::try_new(
Arc::new(Field::new("item", DataType::Null, true)),
OffsetBuffer::new(vec![0; 2].into()),
Arc::new(NullArray::new(0)),
nulls,
)?) as ArrayRef;
Ok(arr)
}
(DataType::Null, _) => Ok(array2.clone()),
(_, DataType::Null) => Ok(array1.clone()),
(DataType::List(field_ref), DataType::List(_)) => {
check_datatypes("array_union", &[array1, array2])?;
let list1 = array1.as_list::<i32>();
let list2 = array2.as_list::<i32>();
let result = union_generic_lists::<i32>(list1, list2, field_ref)?;
Ok(Arc::new(result))
(DataType::List(l_field_ref), DataType::List(r_field_ref)) => {
match (l_field_ref.data_type(), r_field_ref.data_type()) {
(DataType::Null, _) => Ok(array2.clone()),
(_, DataType::Null) => Ok(array1.clone()),
(_, _) => {
let list1 = array1.as_list::<i32>();
let list2 = array2.as_list::<i32>();
let result = union_generic_lists::<i32>(list1, list2, &l_field_ref)?;
Ok(Arc::new(result))
}
}
}
(DataType::LargeList(field_ref), DataType::LargeList(_)) => {
check_datatypes("array_union", &[array1, array2])?;
let list1 = array1.as_list::<i64>();
let list2 = array2.as_list::<i64>();
let result = union_generic_lists::<i64>(list1, list2, field_ref)?;
Ok(Arc::new(result))
(DataType::LargeList(l_field_ref), DataType::LargeList(r_field_ref)) => {
match (l_field_ref.data_type(), r_field_ref.data_type()) {
(DataType::Null, _) => Ok(array2.clone()),
(_, DataType::Null) => Ok(array1.clone()),
(_, _) => {
let list1 = array1.as_list::<i64>();
let list2 = array2.as_list::<i64>();
let result = union_generic_lists::<i64>(list1, list2, &l_field_ref)?;
Ok(Arc::new(result))
}
}
}
_ => {
internal_err!(
Expand Down Expand Up @@ -2032,21 +2012,8 @@ pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
let second_array = &args[1];

match (first_array.data_type(), second_array.data_type()) {
(DataType::Null, DataType::Null) => {
// NullArray(1): means null, NullArray(0): means []
// intersect([], []) = [], intersect([], null) = [], intersect(null, []) = [], intersect(null, null) = null
let nulls = match (first_array.len(), second_array.len()) {
(1, 1) => Some(NullBuffer::new_null(1)),
_ => None,
};
let arr = Arc::new(ListArray::try_new(
Arc::new(Field::new("item", DataType::Null, true)),
OffsetBuffer::new(vec![0; 2].into()),
Arc::new(NullArray::new(0)),
nulls,
)?) as ArrayRef;
Ok(arr)
}
(DataType::Null, _) => Ok(second_array.clone()),
(_, DataType::Null) => Ok(first_array.clone()),
_ => {
let first_array = as_list_array(&first_array)?;
let second_array = as_list_array(&second_array)?;
Expand Down
32 changes: 17 additions & 15 deletions datafusion/sql/src/expr/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@
// under the License.

use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use arrow::array::new_null_array;
use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano;
use arrow::datatypes::DECIMAL128_MAX_PRECISION;
use arrow_schema::DataType;
use datafusion_common::{
not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::expr::{BinaryExpr, Placeholder};
use datafusion_expr::BuiltinScalarFunction;
use datafusion_expr::{lit, Expr, Operator};
use log::debug;
use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value};
use sqlparser::parser::ParserError::ParserError;
use std::borrow::Cow;
use std::collections::HashSet;

impl<'a, S: ContextProvider> SqlToRel<'a, S> {
pub(crate) fn parse_value(
Expand Down Expand Up @@ -138,9 +138,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
schema,
&mut PlannerContext::new(),
)?;

match value {
Expr::Literal(scalar) => {
values.push(scalar);
values.push(Expr::Literal(scalar));
}
Expr::ScalarFunction(ref scalar_function) => {
if scalar_function.fun == BuiltinScalarFunction::MakeArray {
values.push(Expr::ScalarFunction(scalar_function.clone()));
} else {
return not_impl_err!(
"Arrays with elements other than literal are not supported: {value}"
);
}
}
_ => {
return not_impl_err!(
Expand All @@ -150,18 +160,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}

let data_types: HashSet<DataType> =
values.iter().map(|e| e.data_type()).collect();

if data_types.is_empty() {
Ok(lit(ScalarValue::List(new_null_array(&DataType::Null, 0))))
} else if data_types.len() > 1 {
not_impl_err!("Arrays with different types are not supported: {data_types:?}")
} else {
let data_type = values[0].data_type();
let arr = ScalarValue::new_list(&values, &data_type);
Ok(lit(ScalarValue::List(arr)))
}
Ok(Expr::ScalarFunction(ScalarFunction::new(
BuiltinScalarFunction::MakeArray,
values,
)))
}

/// Convert a SQL interval expression to a DataFusion logical plan
Expand Down

0 comments on commit a76beaa

Please sign in to comment.