Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve array_concat signature for null and empty array #8594

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 22 additions & 30 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use crate::{
};

use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit};
use datafusion_common::utils::list_ndims;
use datafusion_common::{internal_err, plan_err, DataFusionError, Result};

use strum::IntoEnumIterator;
Expand Down Expand Up @@ -498,25 +499,6 @@ impl BuiltinScalarFunction {
}
}

/// Returns the dimension [`DataType`] of [`DataType::List`] if
/// treated as a N-dimensional array.
///
/// ## Examples:
///
/// * `Int64` has dimension 1
/// * `List(Int64)` has dimension 2
/// * `List(List(Int64))` has dimension 3
/// * etc.
fn return_dimension(self, input_expr_type: &DataType) -> u64 {
let mut result: u64 = 1;
let mut current_data_type = input_expr_type;
while let DataType::List(field) = current_data_type {
current_data_type = field.data_type();
result += 1;
}
result
}

/// Returns the output [`DataType`] of this function
///
/// This method should be invoked only after `input_expr_types` have been validated
Expand Down Expand Up @@ -552,25 +534,30 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayAppend => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArraySort => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayConcat => {
let mut expr_type = Null;
let mut expr_type: Option<DataType> = None;
let mut max_dims = 0;
for input_expr_type in input_expr_types {
match input_expr_type {
List(field) => {
if !field.data_type().equals_datatype(&Null) {
let dims = self.return_dimension(input_expr_type);
expr_type = match max_dims.cmp(&dims) {
Ordering::Greater => expr_type,
List(_) => {
let dims = list_ndims(input_expr_type);
if let Some(data_type) = expr_type {
let new_type = match max_dims.cmp(&dims) {
Ordering::Greater => data_type,
Ordering::Equal => {
get_wider_type(&expr_type, input_expr_type)?
get_wider_type(&data_type, input_expr_type)?
}
Ordering::Less => {
max_dims = dims;
input_expr_type.clone()
}
};
expr_type = Some(new_type)
} else {
expr_type = Some(input_expr_type.clone());
max_dims = dims;
}
}
DataType::Null => {}
_ => {
return plan_err!(
"The {self} function can only accept list as the args."
Expand All @@ -579,7 +566,11 @@ impl BuiltinScalarFunction {
}
}

Ok(expr_type)
if let Some(expr_type) = expr_type {
Ok(expr_type)
} else {
Ok(DataType::Null)
}
}
BuiltinScalarFunction::ArrayHasAll
| BuiltinScalarFunction::ArrayHasAny
Expand Down Expand Up @@ -929,9 +920,10 @@ impl BuiltinScalarFunction {
}
BuiltinScalarFunction::ArrayPopFront => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayPopBack => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayConcat => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayConcat => Signature {
type_signature: ArrayConcat,
volatility: self.volatility(),
},
BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayEmpty => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()),
Expand Down
6 changes: 6 additions & 0 deletions datafusion/expr/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ pub enum TypeSignature {
/// List dimension of the List/LargeList is equivalent to the number of List.
/// List dimension of the non-list is 0.
ArrayAndElement,
/// Specialized Signature for ArrayConcat
/// Accept arbitrary arguments but they SHOULD be List/LargeList or Null, and the list dimension MAY NOT be the same.
ArrayConcat,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there any usecase for this type of signature other than ArrayConcat? If there is only a single function that would likely have this signature, the code probably doesn't belong in the TypeSignature enum

I don't fully understand

and the list dimension MAY NOT be the same.

For example, one of the tests is

select array_concat([1, null], [null]);

Doesn't that have two arguments of the same dimension?

Copy link
Contributor Author

@jayzhan211 jayzhan211 Jan 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, where should be handled arrayConcat? coerce argument for fun?. I think there is no other array function to share the same coercion with arrayConcat

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ist dimension MAY NOT be the same.
I mean MAY or MAY NOT.
array concat is able to concat different dimension array in DF.

select array_concat([1, null], [null]); is the normal cases, dimensions are both 1.
select array_concat([1, null], null); is also valid, it acts like array_append, where dimensions are 1 and 0.

Copy link
Member

@Weijun-H Weijun-H Jan 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I worked on #8902, I thought this pr would benefit functions like array_union, and array_intersect. Their arguments are array and array.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Weijun-H But array_concat accepts list dimension with 1 and 3 (any n and m). union, intersect should be n and n-1

}

impl TypeSignature {
Expand Down Expand Up @@ -155,6 +158,9 @@ impl TypeSignature {
TypeSignature::ArrayAndElement => {
vec!["ArrayAndElement(List<T>, T)".to_string()]
}
TypeSignature::ArrayConcat => {
vec!["ArrayConcat(List<T> / NULL, .., List<T> / NULL)".to_string()]
}
}
}

Expand Down
18 changes: 17 additions & 1 deletion datafusion/expr/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ use arrow::datatypes::{
};

use datafusion_common::{
exec_datafusion_err, plan_datafusion_err, plan_err, DataFusionError, Result,
exec_datafusion_err, internal_err, plan_datafusion_err, plan_err, DataFusionError,
Result,
};

/// The type signature of an instantiation of binary operator expression such as
Expand Down Expand Up @@ -300,6 +301,21 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
.or_else(|| binary_coercion(lhs_type, rhs_type))
}

/// Coerce all the data types into one final coerced type with `comparison_coercion`
pub fn comparison_coercion_for_iter(data_types: &[DataType]) -> Result<DataType> {
data_types.iter().skip(1).try_fold(
data_types.first().unwrap().clone(),
|current_type, other_type| {
let coerced_type = comparison_coercion(&current_type, other_type);
if let Some(coerced_type) = coerced_type {
Ok(coerced_type)
} else {
internal_err!("Coercion from {current_type:?} to {other_type:?} failed.")
}
},
)
}

/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
/// where one is numeric and one is `Utf8`/`LargeUtf8`.
fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
Expand Down
42 changes: 29 additions & 13 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use arrow::{
use datafusion_common::utils::list_ndims;
use datafusion_common::{internal_err, plan_err, DataFusionError, Result};

use super::binary::comparison_coercion;
use super::binary::{comparison_coercion, comparison_coercion_for_iter};

/// Performs type coercion for function arguments.
///
Expand Down Expand Up @@ -89,18 +89,7 @@ fn get_valid_types(
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::VariadicEqual => {
let new_type = current_types.iter().skip(1).try_fold(
current_types.first().unwrap().clone(),
|acc, x| {
let coerced_type = comparison_coercion(&acc, x);
if let Some(coerced_type) = coerced_type {
Ok(coerced_type)
} else {
internal_err!("Coercion from {acc:?} to {x:?} failed.")
}
},
);

let new_type = comparison_coercion_for_iter(current_types);
match new_type {
Ok(new_type) => vec![vec![new_type; current_types.len()]],
Err(e) => return Err(e),
Expand Down Expand Up @@ -149,6 +138,33 @@ fn get_valid_types(
return Ok(vec![vec![]]);
}
}
TypeSignature::ArrayConcat => {
let base_types = current_types
.iter()
.map(datafusion_common::utils::base_type)
.collect::<Vec<_>>();

let new_base_type = comparison_coercion_for_iter(base_types.as_slice());
match new_base_type {
Ok(new_base_type) => {
let array_types = current_types
.iter()
.map(|t| {
if t.eq(&DataType::Null) {
t.to_owned()
} else {
datafusion_common::utils::coerced_type_with_base_type_only(
t,
&new_base_type,
)
}
})
.collect::<Vec<_>>();
return Ok(vec![array_types]);
}
Err(e) => return Err(e),
}
}
TypeSignature::Any(number) => {
if current_types.len() != *number {
return plan_err!(
Expand Down
19 changes: 13 additions & 6 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use datafusion_common::cast::{
as_generic_list_array, as_generic_string_array, as_int64_array, as_large_list_array,
as_list_array, as_null_array, as_string_array,
};
use datafusion_common::utils::{array_into_list_array, list_ndims};
use datafusion_common::utils::array_into_list_array;
use datafusion_common::{
exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result,
};
Expand Down Expand Up @@ -1084,15 +1084,22 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
let mut new_args = vec![];
for arg in args {
let ndim = list_ndims(arg.data_type());
let base_type = datafusion_common::utils::base_type(arg.data_type());
if ndim == 0 {
return not_impl_err!("Array is not type '{base_type:?}'.");
} else if !base_type.eq(&DataType::Null) {
let data_type = arg.data_type();
if let DataType::List(_) = data_type {
new_args.push(arg.clone());
} else if data_type.eq(&DataType::Null) {
// Null type is valid.
continue;
} else {
return internal_err!("Expect Array type, found {:?}", data_type);
}
}

// All the arguments are null, return null
if new_args.is_empty() {
return Ok(new_null_array(&DataType::Null, 0));
}

concat_internal(new_args.as_slice())
}

Expand Down
91 changes: 87 additions & 4 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1609,22 +1609,25 @@ select array_concat(make_array(), make_array(2, 3));
[2, 3]

# array_concat scalar function #7 (with empty arrays)
## DuckDB and ClickHouse both return '[[1, 2], [3, 4], []]'
query ?
select array_concat(make_array(make_array(1, 2), make_array(3, 4)), make_array(make_array()));
----
[[1, 2], [3, 4]]
[[1, 2], [3, 4], []]

# array_concat scalar function #8 (with empty arrays)
## DuckDB return error, ClickHouse return '[[1, 2], [3, 4], [], [], [], [5, 6], [7, 8]]'
query ?
select array_concat(make_array(make_array(1, 2), make_array(3, 4)), make_array(make_array()), make_array(make_array(), make_array()), make_array(make_array(5, 6), make_array(7, 8)));
select array_concat([[1,2], [3,4]], [[]], [[],[]], [[5,6], [7,8]]);
----
[[1, 2], [3, 4], [5, 6], [7, 8]]
[[1, 2], [3, 4], [], [], [], [5, 6], [7, 8]]

# array_concat scalar function #9 (with empty arrays)
## DuckDB and ClickHouse both return '[[], [1, 2], [3, 4]]'
query ?
select array_concat(make_array(make_array()), make_array(make_array(1, 2), make_array(3, 4)));
----
[[1, 2], [3, 4]]
[[], [1, 2], [3, 4]]

# array_cat scalar function #10 (function alias `array_concat`)
query ??
Expand Down Expand Up @@ -1818,6 +1821,86 @@ select array_concat(make_array(column3), column1, column2) from arrays_values_v2
[, 11, 12]
[]

# array concat with nulls
query ?
select array_concat([1,2,3], null);
----
[1, 2, 3]

query ?
select array_concat(null, [1,2,3]);
----
[1, 2, 3]

query ?
select array_concat([1, null, 2], [3, 4, null]);
----
[1, , 2, 3, 4, ]

query ?
select array_concat([1,2], [null,3]);
----
[1, 2, , 3]

query ?
select array_concat([[1,2]], [[null,3]]);
----
[[1, 2], [, 3]]

query ?
select array_concat([1, null], [[null, 2]]);
----
[[1, ], [, 2]]

query ?
select array_concat([1, null], [null]);
----
[1, , ]

query ?
select array_concat(null, null);
----
NULL

query ?
select array_concat([], null);
----
[]

query ?
select array_concat([], []);
----
[]

query ?
select array_concat([null], [null]);
----
[, ]

# 3D null + 1D + 2D empty
query ?
select array_concat([[[null]]], [1, 2], [[]]);
----
[[[]], [[1, 2]], [[]]]

# 1D + 2D + 3D empty
query ?
select array_concat([], [[]], [[[]]]);
----
[[[]], [[]], [[]]]

# 1D + 2D + 3D null
query ?
select array_concat([null], [[null]], [[[null]]]);
----
[[[]], [[]], [[]]]

# 0D + 1D + 2D + 3D null
query ?
select array_concat(null, [null], [[null]], [[[null]]]);
----
[[[]], [[]], [[]]]

## array_position (aliases: `list_position`, `array_indexof`, `list_indexof`)

# array_position scalar function #1
Expand Down