Skip to content

Commit

Permalink
Fix get_type for higher-order array functions (#13756)
Browse files Browse the repository at this point in the history
* Fix get_type for higher-order array functions

* Fix recursive flatten

The fix is covered by the recursive flatten test case in array.slt.

* Restore "keep LargeList" in Array signature

* clarify naming in the test
  • Loading branch information
findepi authored Dec 18, 2024
1 parent 5500b11 commit 7e0fc14
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 3 deletions.
6 changes: 6 additions & 0 deletions datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,9 @@ pub enum ArrayFunctionSignature {
/// The function takes a single argument that must be a List/LargeList/FixedSizeList
/// or something that can be coerced to one of those types.
Array,
/// The function takes a single argument that must be a List/LargeList/FixedSizeList,
/// which gets coerced to List, with the element type recursively coerced to List too if it is list-like.
RecursiveArray,
/// Specialized Signature for MapArray
/// The function takes a single argument that must be a MapArray
MapArray,
Expand All @@ -227,6 +230,9 @@ impl Display for ArrayFunctionSignature {
ArrayFunctionSignature::Array => {
write!(f, "array")
}
ArrayFunctionSignature::RecursiveArray => {
write!(f, "recursive_array")
}
ArrayFunctionSignature::MapArray => {
write!(f, "map_array")
}
Expand Down
19 changes: 18 additions & 1 deletion datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ use arrow::{
compute::can_cast_types,
datatypes::{DataType, TimeUnit},
};
use datafusion_common::utils::coerced_fixed_size_list_to_list;
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err,
types::{LogicalType, NativeType},
utils::{coerced_fixed_size_list_to_list, list_ndims},
utils::list_ndims,
Result,
};
use datafusion_expr_common::{
Expand Down Expand Up @@ -418,7 +419,16 @@ fn get_valid_types(
_ => Ok(vec![vec![]]),
}
}

/// Returns the array type to use for a list-like input type.
///
/// `List` and `LargeList` inputs are returned unchanged (as a clone), a
/// `FixedSizeList` becomes a `List` sharing the same element field, and every
/// other type yields `None`.
fn array(array_type: &DataType) -> Option<DataType> {
    let coerced = match array_type {
        // The fixed length is dropped; the element field is reused as-is.
        DataType::FixedSizeList(element, _) => DataType::List(Arc::clone(element)),
        DataType::List(_) | DataType::LargeList(_) => array_type.clone(),
        _ => return None,
    };
    Some(coerced)
}

fn recursive_array(array_type: &DataType) -> Option<DataType> {
match array_type {
DataType::List(_)
| DataType::LargeList(_)
Expand Down Expand Up @@ -687,6 +697,13 @@ fn get_valid_types(
array(&current_types[0])
.map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
}
ArrayFunctionSignature::RecursiveArray => {
if current_types.len() != 1 {
return Ok(vec![vec![]]);
}
recursive_array(&current_types[0])
.map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
}
ArrayFunctionSignature::MapArray => {
if current_types.len() != 1 {
return Ok(vec![vec![]]);
Expand Down
83 changes: 83 additions & 0 deletions datafusion/functions-nested/src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -993,3 +993,86 @@ where
let data = mutable.freeze();
Ok(arrow::array::make_array(data))
}

#[cfg(test)]
mod tests {
    use super::array_element_udf;
    use arrow_schema::{DataType, Field};
    use datafusion_common::{Column, DFSchema, ScalarValue};
    use datafusion_expr::expr::ScalarFunction;
    use datafusion_expr::{cast, Expr, ExprSchemable};
    use std::collections::HashMap;

    // Regression test for https://github.com/apache/datafusion/issues/13755
    //
    // Every way of asking `array_element` for its return type must report the
    // list's element type unchanged, i.e. the FixedSizeList itself.
    #[test]
    fn test_array_element_return_type_fixed_size_list() {
        // Element type of the outer list; this is also the expected return type.
        let expected_type = DataType::FixedSizeList(
            Field::new("some_arbitrary_test_field", DataType::Int32, false).into(),
            13,
        );
        let list_type =
            DataType::List(Field::new_list_field(expected_type.clone(), true).into());
        let idx_type = DataType::Int64;
        // Argument types shared by all return-type queries below.
        let arg_types = [list_type.clone(), idx_type.clone()];

        let schema_fields = vec![
            Field::new("my_array", list_type.clone(), false),
            Field::new("my_index", idx_type.clone(), false),
        ];
        let schema =
            DFSchema::from_unqualified_fields(schema_fields.into(), HashMap::default())
                .unwrap();

        let udf = array_element_udf();

        // ScalarUDFImpl::return_type
        let from_types = udf.return_type(&arg_types).unwrap();
        assert_eq!(from_types, expected_type);

        // ScalarUDFImpl::return_type_from_exprs with typed exprs
        let typed_args = [
            cast(Expr::Literal(ScalarValue::Null), list_type.clone()),
            cast(Expr::Literal(ScalarValue::Null), idx_type.clone()),
        ];
        let from_typed_exprs = udf
            .return_type_from_exprs(&typed_args, &schema, &arg_types)
            .unwrap();
        assert_eq!(from_typed_exprs, expected_type);

        // ScalarUDFImpl::return_type_from_exprs with exprs not carrying type
        let column_args = vec![
            Expr::Column(Column::new_unqualified("my_array")),
            Expr::Column(Column::new_unqualified("my_index")),
        ];
        let from_column_exprs = udf
            .return_type_from_exprs(&column_args, &schema, &arg_types)
            .unwrap();
        assert_eq!(from_column_exprs, expected_type);

        // Via ExprSchemable::get_type (e.g. SimplifyInfo)
        let udf_expr = Expr::ScalarFunction(ScalarFunction {
            func: array_element_udf(),
            args: column_args,
        });
        let from_get_type = ExprSchemable::get_type(&udf_expr, &schema).unwrap();
        assert_eq!(from_get_type, expected_type);
    }
}
11 changes: 9 additions & 2 deletions datafusion/functions-nested/src/flatten.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ use datafusion_common::cast::{
use datafusion_common::{exec_err, Result};
use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature,
TypeSignature, Volatility,
};
use std::any::Any;
use std::sync::{Arc, OnceLock};
Expand Down Expand Up @@ -56,7 +57,13 @@ impl Default for Flatten {
impl Flatten {
/// Creates a `Flatten` with an immutable, recursive-array signature and no aliases.
pub fn new() -> Self {
    Self {
        // `signature` is initialized exactly once, using the `RecursiveArray`
        // array signature instead of the plain `Signature::array(..)` form.
        signature: Signature {
            // TODO (https://github.com/apache/datafusion/issues/13757) flatten should be single-step, not recursive
            type_signature: TypeSignature::ArraySignature(
                ArrayFunctionSignature::RecursiveArray,
            ),
            volatility: Volatility::Immutable,
        },
        aliases: vec![],
    }
}
Expand Down

0 comments on commit 7e0fc14

Please sign in to comment.