From 3b6aac2fcecdb003427f9475f061ed2cc52e8558 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Fri, 11 Oct 2024 13:13:36 +0800 Subject: [PATCH] Support struct coercion in `type_union_resolution` (#12839) * support strucy Signed-off-by: jayzhan211 * fix struct Signed-off-by: jayzhan211 * rm todo Signed-off-by: jayzhan211 * add more test Signed-off-by: jayzhan211 * fix field order Signed-off-by: jayzhan211 * add lsit of stuct test Signed-off-by: jayzhan211 * upd err msg Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- .../expr-common/src/type_coercion/binary.rs | 46 +++++++- .../expr/src/type_coercion/functions.rs | 39 +++++-- datafusion/functions-nested/src/make_array.rs | 48 ++++++++ datafusion/sqllogictest/test_files/struct.slt | 109 +++++++++++++++++- 4 files changed, 228 insertions(+), 14 deletions(-) diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 6d66b8b4df44..e042dd5d3ac6 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -25,8 +25,8 @@ use crate::operator::Operator; use arrow::array::{new_empty_array, Array}; use arrow::compute::can_cast_types; use arrow::datatypes::{ - DataType, Field, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, - DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, + DataType, Field, FieldRef, Fields, TimeUnit, DECIMAL128_MAX_PRECISION, + DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; use datafusion_common::{exec_datafusion_err, plan_datafusion_err, plan_err, Result}; @@ -370,6 +370,8 @@ impl From<&DataType> for TypeCategory { /// align with the behavior of Postgres. Therefore, we've made slight adjustments to the rules /// to better match the behavior of both Postgres and DuckDB. For example, we expect adjusted /// decimal precision and scale when coercing decimal types. +/// +/// This function doesn't preserve correct field name and nullability for the struct type, we only care about data type. pub fn type_union_resolution(data_types: &[DataType]) -> Option { if data_types.is_empty() { return None; @@ -476,6 +478,46 @@ fn type_union_resolution_coercion( type_union_resolution_coercion(lhs.data_type(), rhs.data_type()); new_item_type.map(|t| DataType::List(Arc::new(Field::new("item", t, true)))) } + (DataType::Struct(lhs), DataType::Struct(rhs)) => { + if lhs.len() != rhs.len() { + return None; + } + + // Search the field in the right hand side with the SAME field name + fn search_corresponding_coerced_type( + lhs_field: &FieldRef, + rhs: &Fields, + ) -> Option { + for rhs_field in rhs.iter() { + if lhs_field.name() == rhs_field.name() { + if let Some(t) = type_union_resolution_coercion( + lhs_field.data_type(), + rhs_field.data_type(), + ) { + return Some(t); + } else { + return None; + } + } + } + + None + } + + let types = lhs + .iter() + .map(|lhs_field| search_corresponding_coerced_type(lhs_field, rhs)) + .collect::>>()?; + + let fields = types + .into_iter() + .enumerate() + .map(|(i, datatype)| { + Arc::new(Field::new(format!("c{i}"), datatype, true)) + }) + .collect::>(); + Some(DataType::Struct(fields.into())) + } _ => { // numeric coercion is the same as comparison coercion, both find the narrowest type // that can accommodate both types diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 143e00fa409e..85f8e20ba4a5 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -221,20 +221,37 @@ fn get_valid_types_with_scalar_udf( current_types: &[DataType], func: &ScalarUDF, ) -> Result>> { - let valid_types = match signature { + match signature { TypeSignature::UserDefined => match func.coerce_types(current_types) { - Ok(coerced_types) => vec![coerced_types], - Err(e) => return exec_err!("User-defined coercion failed with {:?}", e), + Ok(coerced_types) => Ok(vec![coerced_types]), + Err(e) => exec_err!("User-defined coercion failed with {:?}", e), }, - TypeSignature::OneOf(signatures) => signatures - .iter() - .filter_map(|t| get_valid_types_with_scalar_udf(t, current_types, func).ok()) - .flatten() - .collect::>(), - _ => get_valid_types(signature, current_types)?, - }; + TypeSignature::OneOf(signatures) => { + let mut res = vec![]; + let mut errors = vec![]; + for sig in signatures { + match get_valid_types_with_scalar_udf(sig, current_types, func) { + Ok(valid_types) => { + res.extend(valid_types); + } + Err(e) => { + errors.push(e.to_string()); + } + } + } - Ok(valid_types) + // Every signature failed, return the joined error + if res.is_empty() { + internal_err!( + "Failed to match any signature, errors: {}", + errors.join(",") + ) + } else { + Ok(res) + } + } + _ => get_valid_types(signature, current_types), + } } fn get_valid_types_with_aggregate_udf( diff --git a/datafusion/functions-nested/src/make_array.rs b/datafusion/functions-nested/src/make_array.rs index 51fc71e6b09d..cafa073f9191 100644 --- a/datafusion/functions-nested/src/make_array.rs +++ b/datafusion/functions-nested/src/make_array.rs @@ -27,10 +27,12 @@ use arrow_array::{ use arrow_buffer::OffsetBuffer; use arrow_schema::DataType::{LargeList, List, Null}; use arrow_schema::{DataType, Field}; +use datafusion_common::{exec_err, internal_err}; use datafusion_common::{plan_err, utils::array_into_list_array_nullable, Result}; use datafusion_expr::binary::type_union_resolution; use datafusion_expr::TypeSignature; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use itertools::Itertools; use crate::utils::make_scalar_function; @@ -106,6 +108,32 @@ impl ScalarUDFImpl for MakeArray { fn coerce_types(&self, arg_types: &[DataType]) -> Result> { if let Some(new_type) = type_union_resolution(arg_types) { + // TODO: Move the logic to type_union_resolution if this applies to other functions as well + // Handle struct where we only change the data type but preserve the field name and nullability. + // Since field name is the key of the struct, so it shouldn't be updated to the common column name like "c0" or "c1" + let is_struct_and_has_same_key = are_all_struct_and_have_same_key(arg_types)?; + if is_struct_and_has_same_key { + let data_types: Vec<_> = if let DataType::Struct(fields) = &arg_types[0] { + fields.iter().map(|f| f.data_type().to_owned()).collect() + } else { + return internal_err!("Struct type is checked is the previous function, so this should be unreachable"); + }; + + let mut final_struct_types = vec![]; + for s in arg_types { + let mut new_fields = vec![]; + if let DataType::Struct(fields) = s { + for (i, f) in fields.iter().enumerate() { + let field = Arc::unwrap_or_clone(Arc::clone(f)) + .with_data_type(data_types[i].to_owned()); + new_fields.push(Arc::new(field)); + } + } + final_struct_types.push(DataType::Struct(new_fields.into())) + } + return Ok(final_struct_types); + } + if let DataType::FixedSizeList(field, _) = new_type { Ok(vec![DataType::List(field); arg_types.len()]) } else if new_type.is_null() { @@ -123,6 +151,26 @@ impl ScalarUDFImpl for MakeArray { } } +fn are_all_struct_and_have_same_key(data_types: &[DataType]) -> Result { + let mut keys_string: Option = None; + for data_type in data_types { + if let DataType::Struct(fields) = data_type { + let keys = fields.iter().map(|f| f.name().to_owned()).join(","); + if let Some(ref k) = keys_string { + if *k != keys { + return exec_err!("Expect same keys for struct type but got mismatched pair {} and {}", *k, keys); + } + } else { + keys_string = Some(keys); + } + } else { + return Ok(false); + } + } + + Ok(true) +} + // Empty array is a special case that is useful for many other array functions pub(super) fn empty_array_type() -> DataType { DataType::List(Arc::new(Field::new("item", DataType::Int64, true))) diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index 67cd7d71fc1c..b76c78396aed 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -374,6 +374,34 @@ You reached the bottom! statement ok drop view complex_view; +# struct with different keys r1 and r2 is not valid +statement ok +create table t(a struct, b struct) as values (struct('red', 1), struct('blue', 2.3)); + +# Expect same keys for struct type but got mismatched pair r1,c and r2,c +query error +select [a, b] from t; + +statement ok +drop table t; + +# struct with the same key +statement ok +create table t(a struct, b struct) as values (struct('red', 1), struct('blue', 2.3)); + +query T +select arrow_typeof([a, b]) from t; +---- +List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + +query ? +select [a, b] from t; +---- +[{r: red, c: 1}, {r: blue, c: 2}] + +statement ok +drop table t; + # Test row alias query ? @@ -412,7 +440,6 @@ select * from t; ---- {r: red, b: 2} {r: blue, b: 2.3} -# TODO: Should be coerced to float query T select arrow_typeof(c1) from t; ---- @@ -422,3 +449,83 @@ query T select arrow_typeof(c2) from t; ---- Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +statement ok +drop table t; + +################################## +## Test Coalesce with Struct +################################## + +statement ok +CREATE TABLE t ( + s1 struct(a int, b varchar), + s2 struct(a float, b varchar) +) AS VALUES + (row(1, 'red'), row(1.1, 'string1')), + (row(2, 'blue'), row(2.2, 'string2')), + (row(3, 'green'), row(33.2, 'string3')) +; + +query ? +select coalesce(s1) from t; +---- +{a: 1, b: red} +{a: 2, b: blue} +{a: 3, b: green} + +# TODO: a's type should be float +query T +select arrow_typeof(coalesce(s1)) from t; +---- +Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +statement ok +drop table t; + +statement ok +CREATE TABLE t ( + s1 struct(a int, b varchar), + s2 struct(a float, b varchar) +) AS VALUES + (row(1, 'red'), row(1.1, 'string1')), + (null, row(2.2, 'string2')), + (row(3, 'green'), row(33.2, 'string3')) +; + +# TODO: second column should not be null +query ? +select coalesce(s1) from t; +---- +{a: 1, b: red} +NULL +{a: 3, b: green} + +statement ok +drop table t; + +# row() with incorrect order +statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'blue' to value of Float64 type +create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values + (row('red', 1), row(2.3, 'blue')), + (row('purple', 1), row('green', 2.3)); + +# out of order struct literal +# TODO: This query should not fail +statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type +create table t(a struct(r varchar, c int)) as values ({r: 'a', c: 1}), ({c: 2, r: 'b'}); + +################################## +## Test Array of Struct +################################## + +query ? +select [{r: 'a', c: 1}, {r: 'b', c: 2}]; +---- +[{r: a, c: 1}, {r: b, c: 2}] + +# Can't create a list of struct with different field types +query error +select [{r: 'a', c: 1}, {c: 2, r: 'b'}];