Skip to content

Commit

Permalink
Fix recursive flatten
Browse files Browse the repository at this point in the history
  • Loading branch information
findepi committed Dec 13, 2024
1 parent 6903259 commit 1bd311a
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 2 deletions.
6 changes: 6 additions & 0 deletions datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ pub enum ArrayFunctionSignature {
/// The function takes a single argument that must be a List/LargeList/FixedSizeList
/// or something that can be coerced to one of those types.
Array,
/// A function takes a single argument that must be a List/LargeList/FixedSizeList
/// which gets coerced to List, with element type recursively coerced to List too if it is list-like.
RecursiveArray,
/// Specialized Signature for MapArray
/// The function takes a single argument that must be a MapArray
MapArray,
Expand All @@ -198,6 +201,9 @@ impl std::fmt::Display for ArrayFunctionSignature {
ArrayFunctionSignature::Array => {
write!(f, "array")
}
ArrayFunctionSignature::RecursiveArray => {
write!(f, "recursive_array")
}
ArrayFunctionSignature::MapArray => {
write!(f, "map_array")
}
Expand Down
21 changes: 21 additions & 0 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use arrow::{
compute::can_cast_types,
datatypes::{DataType, TimeUnit},
};
use datafusion_common::utils::coerced_fixed_size_list_to_list;
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, plan_err,
types::{LogicalType, NativeType},
Expand Down Expand Up @@ -414,6 +415,7 @@ fn get_valid_types(
_ => Ok(vec![vec![]]),
}
}

fn array(array_type: &DataType) -> Option<DataType> {
match array_type {
DataType::List(_) => Some(array_type.clone()),
Expand All @@ -424,6 +426,18 @@ fn get_valid_types(
}
}

fn recursive_array(array_type: &DataType) -> Option<DataType> {
match array_type {
DataType::List(_)
| DataType::LargeList(_)
| DataType::FixedSizeList(_, _) => {
let array_type = coerced_fixed_size_list_to_list(array_type);
Some(array_type)
}
_ => None,
}
}

fn function_length_check(length: usize, expected_length: usize) -> Result<()> {
if length < 1 {
return plan_err!(
Expand Down Expand Up @@ -651,6 +665,13 @@ fn get_valid_types(
array(&current_types[0])
.map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
}
ArrayFunctionSignature::RecursiveArray => {
if current_types.len() != 1 {
return Ok(vec![vec![]]);
}
recursive_array(&current_types[0])
.map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
}
ArrayFunctionSignature::MapArray => {
if current_types.len() != 1 {
return Ok(vec![vec![]]);
Expand Down
11 changes: 9 additions & 2 deletions datafusion/functions-nested/src/flatten.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ use datafusion_common::cast::{
use datafusion_common::{exec_err, Result};
use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature,
TypeSignature, Volatility,
};
use std::any::Any;
use std::sync::{Arc, OnceLock};
Expand Down Expand Up @@ -56,7 +57,13 @@ impl Default for Flatten {
impl Flatten {
pub fn new() -> Self {
Self {
signature: Signature::array(Volatility::Immutable),
signature: Signature {
// TODO (https://github.com/apache/datafusion/issues/13757) flatten should be single-step, not recursive
type_signature: TypeSignature::ArraySignature(
ArrayFunctionSignature::RecursiveArray,
),
volatility: Volatility::Immutable,
},
aliases: vec![],
}
}
Expand Down

0 comments on commit 1bd311a

Please sign in to comment.