make array_union/array_except/array_intersect handle empty/null arrays rightly (#8269)

* make array_union handle empty/null arrays rightly

Signed-off-by: veeupup <[email protected]>

* make array_except handle empty/null arrays rightly

Signed-off-by: veeupup <[email protected]>

* make array_intersect handle empty/null arrays rightly

Signed-off-by: veeupup <[email protected]>

* fix sql_array_literal

Signed-off-by: veeupup <[email protected]>

* fix comments

---------

Signed-off-by: veeupup <[email protected]>
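
As a sketch of the intended end-user behavior (illustrative queries in the sqllogictest style used below, not taken verbatim from this commit's test suite):

query ?
SELECT array_union([1, 2, 3], []);
----
[1, 2, 3]

query ?
SELECT array_union(NULL, [1, 2, 3]);
----
[1, 2, 3]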
Veeupup authored Nov 21, 2023
1 parent 47b4972 commit 54a0247
Showing 6 changed files with 164 additions and 103 deletions.
18 changes: 15 additions & 3 deletions datafusion/expr/src/built_in_function.rs
@@ -599,12 +599,24 @@ impl BuiltinScalarFunction {
             BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()),
             BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()),
             BuiltinScalarFunction::ArrayToString => Ok(Utf8),
-            BuiltinScalarFunction::ArrayIntersect => Ok(input_expr_types[0].clone()),
-            BuiltinScalarFunction::ArrayUnion => Ok(input_expr_types[0].clone()),
+            BuiltinScalarFunction::ArrayUnion | BuiltinScalarFunction::ArrayIntersect => {
+                match (input_expr_types[0].clone(), input_expr_types[1].clone()) {
+                    (DataType::Null, dt) => Ok(dt),
+                    (dt, DataType::Null) => Ok(dt),
+                    (dt, _) => Ok(dt),
+                }
+            }
             BuiltinScalarFunction::Range => {
                 Ok(List(Arc::new(Field::new("item", Int64, true))))
             }
-            BuiltinScalarFunction::ArrayExcept => Ok(input_expr_types[0].clone()),
+            BuiltinScalarFunction::ArrayExcept => {
+                match (input_expr_types[0].clone(), input_expr_types[1].clone()) {
+                    (DataType::Null, _) | (_, DataType::Null) => {
+                        Ok(input_expr_types[0].clone())
+                    }
+                    (dt, _) => Ok(dt),
+                }
+            }
             BuiltinScalarFunction::Cardinality => Ok(UInt64),
             BuiltinScalarFunction::MakeArray => match input_expr_types.len() {
                 0 => Ok(List(Arc::new(Field::new("item", Null, true)))),
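With this change the return type of array_union, array_intersect, and array_except is taken from a non-null argument instead of unconditionally from the first argument, so a NULL operand no longer forces a Null result type. An illustrative probe (arrow_typeof is DataFusion's type-introspection function; the exact rendering of the list type may differ):

-- illustrative: the result is typed from the second, non-null argument
SELECT arrow_typeof(array_union(NULL, [1, 2, 3]));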
137 changes: 79 additions & 58 deletions datafusion/physical-expr/src/array_expressions.rs
@@ -228,10 +228,10 @@ fn compute_array_dims(arr: Option<ArrayRef>) -> Result<Option<Vec<Option<u64>>>>
 
 fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> {
     let data_type = args[0].data_type();
-    if !args
-        .iter()
-        .all(|arg| arg.data_type().equals_datatype(data_type))
-    {
+    if !args.iter().all(|arg| {
+        arg.data_type().equals_datatype(data_type)
+            || arg.data_type().equals_datatype(&DataType::Null)
+    }) {
         let types = args.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
         return plan_err!("{name} received incompatible types: '{types:?}'.");
     }
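check_datatypes now accepts a Null-typed argument alongside typed ones, so an untyped empty or all-NULL array no longer trips the incompatible-types plan error. A sketch of a call this relaxation is meant to admit (assuming [] is planned with a Null element type):

-- previously rejected by check_datatypes, now valid
SELECT array_union([1, 2], []);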
Expand Down Expand Up @@ -1512,19 +1512,29 @@ pub fn array_union(args: &[ArrayRef]) -> Result<ArrayRef> {
match (array1.data_type(), array2.data_type()) {
(DataType::Null, _) => Ok(array2.clone()),
(_, DataType::Null) => Ok(array1.clone()),
(DataType::List(field_ref), DataType::List(_)) => {
check_datatypes("array_union", &[array1, array2])?;
let list1 = array1.as_list::<i32>();
let list2 = array2.as_list::<i32>();
let result = union_generic_lists::<i32>(list1, list2, field_ref)?;
Ok(Arc::new(result))
(DataType::List(l_field_ref), DataType::List(r_field_ref)) => {
match (l_field_ref.data_type(), r_field_ref.data_type()) {
(DataType::Null, _) => Ok(array2.clone()),
(_, DataType::Null) => Ok(array1.clone()),
(_, _) => {
let list1 = array1.as_list::<i32>();
let list2 = array2.as_list::<i32>();
let result = union_generic_lists::<i32>(list1, list2, l_field_ref)?;
Ok(Arc::new(result))
}
}
}
(DataType::LargeList(field_ref), DataType::LargeList(_)) => {
check_datatypes("array_union", &[array1, array2])?;
let list1 = array1.as_list::<i64>();
let list2 = array2.as_list::<i64>();
let result = union_generic_lists::<i64>(list1, list2, field_ref)?;
Ok(Arc::new(result))
(DataType::LargeList(l_field_ref), DataType::LargeList(r_field_ref)) => {
match (l_field_ref.data_type(), r_field_ref.data_type()) {
(DataType::Null, _) => Ok(array2.clone()),
(_, DataType::Null) => Ok(array1.clone()),
(_, _) => {
let list1 = array1.as_list::<i64>();
let list2 = array2.as_list::<i64>();
let result = union_generic_lists::<i64>(list1, list2, l_field_ref)?;
Ok(Arc::new(result))
}
}
}
_ => {
internal_err!(
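The List and LargeList arms now also inspect the element types: if either side's element type is Null (e.g. an empty array literal), the other array is returned unchanged instead of attempting a union over a Null element type. Illustrative:

query ?
SELECT array_union([], [4, 5]);
----
[4, 5]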
Expand Down Expand Up @@ -1919,55 +1929,66 @@ pub fn string_to_array<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef
pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
assert_eq!(args.len(), 2);

let first_array = as_list_array(&args[0])?;
let second_array = as_list_array(&args[1])?;
let first_array = &args[0];
let second_array = &args[1];

if first_array.value_type() != second_array.value_type() {
return internal_err!("array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'");
}
let dt = first_array.value_type();
match (first_array.data_type(), second_array.data_type()) {
(DataType::Null, _) => Ok(second_array.clone()),
(_, DataType::Null) => Ok(first_array.clone()),
_ => {
let first_array = as_list_array(&first_array)?;
let second_array = as_list_array(&second_array)?;

let mut offsets = vec![0];
let mut new_arrays = vec![];

let converter = RowConverter::new(vec![SortField::new(dt.clone())])?;
for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) {
if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) {
let l_values = converter.convert_columns(&[first_arr])?;
let r_values = converter.convert_columns(&[second_arr])?;

let values_set: HashSet<_> = l_values.iter().collect();
let mut rows = Vec::with_capacity(r_values.num_rows());
for r_val in r_values.iter().sorted().dedup() {
if values_set.contains(&r_val) {
rows.push(r_val);
}
if first_array.value_type() != second_array.value_type() {
return internal_err!("array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'");
}

let last_offset: i32 = match offsets.last().copied() {
Some(offset) => offset,
None => return internal_err!("offsets should not be empty"),
};
offsets.push(last_offset + rows.len() as i32);
let arrays = converter.convert_rows(rows)?;
let array = match arrays.get(0) {
Some(array) => array.clone(),
None => {
return internal_err!(
"array_intersect: failed to get array from rows"
)
let dt = first_array.value_type();

let mut offsets = vec![0];
let mut new_arrays = vec![];

let converter = RowConverter::new(vec![SortField::new(dt.clone())])?;
for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) {
if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) {
let l_values = converter.convert_columns(&[first_arr])?;
let r_values = converter.convert_columns(&[second_arr])?;

let values_set: HashSet<_> = l_values.iter().collect();
let mut rows = Vec::with_capacity(r_values.num_rows());
for r_val in r_values.iter().sorted().dedup() {
if values_set.contains(&r_val) {
rows.push(r_val);
}
}

let last_offset: i32 = match offsets.last().copied() {
Some(offset) => offset,
None => return internal_err!("offsets should not be empty"),
};
offsets.push(last_offset + rows.len() as i32);
let arrays = converter.convert_rows(rows)?;
let array = match arrays.get(0) {
Some(array) => array.clone(),
None => {
return internal_err!(
"array_intersect: failed to get array from rows"
)
}
};
new_arrays.push(array);
}
};
new_arrays.push(array);
}

let field = Arc::new(Field::new("item", dt, true));
let offsets = OffsetBuffer::new(offsets.into());
let new_arrays_ref =
new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
let values = compute::concat(&new_arrays_ref)?;
let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?);
Ok(arr)
}
}

let field = Arc::new(Field::new("item", dt, true));
let offsets = OffsetBuffer::new(offsets.into());
let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
let values = compute::concat(&new_arrays_ref)?;
let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?);
Ok(arr)
}

#[cfg(test)]
Expand Down
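array_intersect now short-circuits on an untyped NULL argument before calling as_list_array: per the two new match arms, whichever side is non-null is returned unchanged. Illustrative:

query ?
SELECT array_intersect([1, 2, 3], NULL);
----
[1, 2, 3]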
34 changes: 18 additions & 16 deletions datafusion/sql/src/expr/value.rs
@@ -16,20 +16,20 @@
 // under the License.
 
 use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
-use arrow::array::new_null_array;
 use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano;
 use arrow::datatypes::DECIMAL128_MAX_PRECISION;
 use arrow_schema::DataType;
 use datafusion_common::{
     not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue,
 };
+use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::expr::{BinaryExpr, Placeholder};
+use datafusion_expr::BuiltinScalarFunction;
 use datafusion_expr::{lit, Expr, Operator};
 use log::debug;
 use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value};
 use sqlparser::parser::ParserError::ParserError;
 use std::borrow::Cow;
-use std::collections::HashSet;
 
 impl<'a, S: ContextProvider> SqlToRel<'a, S> {
     pub(crate) fn parse_value(
Expand Down Expand Up @@ -138,9 +138,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
schema,
&mut PlannerContext::new(),
)?;

match value {
Expr::Literal(scalar) => {
values.push(scalar);
Expr::Literal(_) => {
values.push(value);
}
Expr::ScalarFunction(ref scalar_function) => {
if scalar_function.fun == BuiltinScalarFunction::MakeArray {
values.push(value);
} else {
return not_impl_err!(
"ScalarFunctions without MakeArray are not supported: {value}"
);
}
}
_ => {
return not_impl_err!(
@@ -150,18 +160,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             }
         }
 
-        let data_types: HashSet<DataType> =
-            values.iter().map(|e| e.data_type()).collect();
-
-        if data_types.is_empty() {
-            Ok(lit(ScalarValue::List(new_null_array(&DataType::Null, 0))))
-        } else if data_types.len() > 1 {
-            not_impl_err!("Arrays with different types are not supported: {data_types:?}")
-        } else {
-            let data_type = values[0].data_type();
-            let arr = ScalarValue::new_list(&values, &data_type);
-            Ok(lit(ScalarValue::List(arr)))
-        }
+        Ok(Expr::ScalarFunction(ScalarFunction::new(
+            BuiltinScalarFunction::MakeArray,
+            values,
+        )))
     }
 
     /// Convert a SQL interval expression to a DataFusion logical plan
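An array literal is now planned as a make_array scalar-function call rather than eagerly folded into a ScalarValue::List, which lets nested arrays and untyped elements go through the normal function type-resolution path. A sketch of the resulting plans (rendering abbreviated):

-- planned as make_array(Int64(1), Int64(2), Int64(3))
SELECT [1, 2, 3];

-- nested literals become nested make_array calls
SELECT [[1, 2], [3, 4]];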
22 changes: 0 additions & 22 deletions datafusion/sql/tests/sql_integration.rs
@@ -1383,18 +1383,6 @@ fn select_interval_out_of_range() {
     );
 }
 
-#[test]
-fn select_array_no_common_type() {
-    let sql = "SELECT [1, true, null]";
-    let err = logical_plan(sql).expect_err("query should have failed");
-
-    // HashSet doesn't guarantee order
-    assert_contains!(
-        err.strip_backtrace(),
-        "This feature is not implemented: Arrays with different types are not supported: "
-    );
-}
-
 #[test]
 fn recursive_ctes() {
     let sql = "
Expand All @@ -1411,16 +1399,6 @@ fn recursive_ctes() {
);
}

#[test]
fn select_array_non_literal_type() {
let sql = "SELECT [now()]";
let err = logical_plan(sql).expect_err("query should have failed");
assert_eq!(
"This feature is not implemented: Arrays with elements other than literal are not supported: now()",
err.strip_backtrace()
);
}

#[test]
fn select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() {
quick_test(
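Both deleted tests asserted planner-time errors that sql_array_literal can no longer produce: array literals now plan as make_array, so mixed element types and non-literal elements are deferred to later type resolution rather than rejected here. Illustrative queries that previously failed at this stage:

-- formerly "Arrays with different types are not supported"
SELECT [1, true, null];

-- formerly "Arrays with elements other than literal are not supported"
SELECT [now()];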
4 changes: 2 additions & 2 deletions datafusion/sqllogictest/test_files/aggregate.slt
@@ -1396,7 +1396,7 @@ SELECT COUNT(DISTINCT c1) FROM test
 query ?
 SELECT ARRAY_AGG([])
 ----
-[]
+[[]]
 
 # array_agg_one
 query ?
@@ -1419,7 +1419,7 @@
 query ?
 SELECT ARRAY_AGG([]);
 ----
-[]
+[[]]
 
 # array_agg_one
 query ?
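The updated expectation follows from the literal change: [] now plans as make_array() and evaluates to a single row holding an empty list, so ARRAY_AGG wraps that one row in an outer list. A reading of the test (illustrative):

-- one input row whose value is the empty list; ARRAY_AGG nests it
SELECT ARRAY_AGG([]);   -- [[]]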