From 7e0fc146add0a29cd2e63df9fb82097e76c19a67 Mon Sep 17 00:00:00 2001
From: Piotr Findeisen <piotr.findeisen@gmail.com>
Date: Wed, 18 Dec 2024 08:15:38 +0100
Subject: [PATCH] Fix get_type for higher-order array functions (#13756)

* Fix get_type for higher-order array functions

* Fix recursive flatten

The fix is covered by the recursive flatten test case in array.slt.

* Restore "keep LargeList" in Array signature

* Clarify naming in the test
---
 datafusion/expr-common/src/signature.rs       |  6 ++
 .../expr/src/type_coercion/functions.rs       | 19 ++++-
 datafusion/functions-nested/src/extract.rs    | 83 +++++++++++++++++++
 datafusion/functions-nested/src/flatten.rs    | 11 ++-
 4 files changed, 116 insertions(+), 3 deletions(-)

diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs
index 148ddac73a57..4f97dfe9c8f0 100644
--- a/datafusion/expr-common/src/signature.rs
+++ b/datafusion/expr-common/src/signature.rs
@@ -204,6 +204,9 @@ pub enum ArrayFunctionSignature {
     /// The function takes a single argument that must be a List/LargeList/FixedSizeList
     /// or something that can be coerced to one of those types.
     Array,
+    /// The function takes a single argument that must be a List/LargeList/FixedSizeList,
+    /// which gets coerced to List; if the element type is list-like, it is recursively coerced to List too.
+    RecursiveArray,
     /// Specialized Signature for MapArray
     /// The function takes a single argument that must be a MapArray
     MapArray,
@@ -227,6 +230,9 @@ impl Display for ArrayFunctionSignature {
             ArrayFunctionSignature::Array => {
                 write!(f, "array")
             }
+            ArrayFunctionSignature::RecursiveArray => {
+                write!(f, "recursive_array")
+            }
             ArrayFunctionSignature::MapArray => {
                 write!(f, "map_array")
             }
diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs
index b12489167b8f..c20625cbc2f6 100644
--- a/datafusion/expr/src/type_coercion/functions.rs
+++ b/datafusion/expr/src/type_coercion/functions.rs
@@ -21,10 +21,11 @@ use arrow::{
     compute::can_cast_types,
     datatypes::{DataType, TimeUnit},
 };
+use datafusion_common::utils::coerced_fixed_size_list_to_list;
 use datafusion_common::{
     exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err,
     types::{LogicalType, NativeType},
-    utils::{coerced_fixed_size_list_to_list, list_ndims},
+    utils::list_ndims,
     Result,
 };
 use datafusion_expr_common::{
@@ -418,7 +419,16 @@ fn get_valid_types(
             _ => Ok(vec![vec![]]),
         }
     }
+
     fn array(array_type: &DataType) -> Option<DataType> {
+        match array_type {
+            DataType::List(_) | DataType::LargeList(_) => Some(array_type.clone()),
+            DataType::FixedSizeList(field, _) => Some(DataType::List(Arc::clone(field))),
+            _ => None,
+        }
+    }
+
+    fn recursive_array(array_type: &DataType) -> Option<DataType> {
         match array_type {
             DataType::List(_)
             | DataType::LargeList(_)
@@ -687,6 +697,13 @@ fn get_valid_types(
                 array(&current_types[0])
                     .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
             }
+            ArrayFunctionSignature::RecursiveArray => {
+                if current_types.len() != 1 {
+                    return Ok(vec![vec![]]);
+                }
+                recursive_array(&current_types[0])
+                    .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
+            }
             ArrayFunctionSignature::MapArray => {
                 if current_types.len() != 1 {
                     return Ok(vec![vec![]]);
diff --git a/datafusion/functions-nested/src/extract.rs b/datafusion/functions-nested/src/extract.rs
index fc35f0076330..f972597bbf84 100644
--- a/datafusion/functions-nested/src/extract.rs
+++ b/datafusion/functions-nested/src/extract.rs
@@ -993,3 +993,86 @@ where
     let data = mutable.freeze();
     Ok(arrow::array::make_array(data))
 }
+
+#[cfg(test)]
+mod tests {
+    use super::array_element_udf;
+    use arrow_schema::{DataType, Field};
+    use datafusion_common::{Column, DFSchema, ScalarValue};
+    use datafusion_expr::expr::ScalarFunction;
+    use datafusion_expr::{cast, Expr, ExprSchemable};
+    use std::collections::HashMap;
+
+    // Regression test for https://github.com/apache/datafusion/issues/13755
+    #[test]
+    fn test_array_element_return_type_fixed_size_list() {
+        let fixed_size_list_type = DataType::FixedSizeList(
+            Field::new("some_arbitrary_test_field", DataType::Int32, false).into(),
+            13,
+        );
+        let array_type = DataType::List(
+            Field::new_list_field(fixed_size_list_type.clone(), true).into(),
+        );
+        let index_type = DataType::Int64;
+
+        let schema = DFSchema::from_unqualified_fields(
+            vec![
+                Field::new("my_array", array_type.clone(), false),
+                Field::new("my_index", index_type.clone(), false),
+            ]
+            .into(),
+            HashMap::default(),
+        )
+        .unwrap();
+
+        let udf = array_element_udf();
+
+        // ScalarUDFImpl::return_type
+        assert_eq!(
+            udf.return_type(&[array_type.clone(), index_type.clone()])
+                .unwrap(),
+            fixed_size_list_type
+        );
+
+        // ScalarUDFImpl::return_type_from_exprs with typed exprs
+        assert_eq!(
+            udf.return_type_from_exprs(
+                &[
+                    cast(Expr::Literal(ScalarValue::Null), array_type.clone()),
+                    cast(Expr::Literal(ScalarValue::Null), index_type.clone()),
+                ],
+                &schema,
+                &[array_type.clone(), index_type.clone()]
+            )
+            .unwrap(),
+            fixed_size_list_type
+        );
+
+        // ScalarUDFImpl::return_type_from_exprs with exprs not carrying type
+        assert_eq!(
+            udf.return_type_from_exprs(
+                &[
+                    Expr::Column(Column::new_unqualified("my_array")),
+                    Expr::Column(Column::new_unqualified("my_index")),
+                ],
+                &schema,
+                &[array_type.clone(), index_type.clone()]
+            )
+            .unwrap(),
+            fixed_size_list_type
+        );
+
+        // Via ExprSchemable::get_type (e.g. SimplifyInfo)
+        let udf_expr = Expr::ScalarFunction(ScalarFunction {
+            func: array_element_udf(),
+            args: vec![
+                Expr::Column(Column::new_unqualified("my_array")),
+                Expr::Column(Column::new_unqualified("my_index")),
+            ],
+        });
+        assert_eq!(
+            ExprSchemable::get_type(&udf_expr, &schema).unwrap(),
+            fixed_size_list_type
+        );
+    }
+}
diff --git a/datafusion/functions-nested/src/flatten.rs b/datafusion/functions-nested/src/flatten.rs
index 9d2cb8a3f667..7cb52ae4c5c9 100644
--- a/datafusion/functions-nested/src/flatten.rs
+++ b/datafusion/functions-nested/src/flatten.rs
@@ -28,7 +28,8 @@ use datafusion_common::cast::{
 use datafusion_common::{exec_err, Result};
 use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
 use datafusion_expr::{
-    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+    ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature,
+    TypeSignature, Volatility,
 };
 use std::any::Any;
 use std::sync::{Arc, OnceLock};
@@ -56,7 +57,13 @@ impl Default for Flatten {
 impl Flatten {
     pub fn new() -> Self {
         Self {
-            signature: Signature::array(Volatility::Immutable),
+            signature: Signature {
+                // TODO (https://github.com/apache/datafusion/issues/13757) flatten should be single-step, not recursive
+                type_signature: TypeSignature::ArraySignature(
+                    ArrayFunctionSignature::RecursiveArray,
+                ),
+                volatility: Volatility::Immutable,
+            },
             aliases: vec![],
         }
     }