Skip to content

Commit

Permalink
Support struct coercion in type_union_resolution (#12839)
Browse files Browse the repository at this point in the history
* support struct

Signed-off-by: jayzhan211 <[email protected]>

* fix struct

Signed-off-by: jayzhan211 <[email protected]>

* rm todo

Signed-off-by: jayzhan211 <[email protected]>

* add more test

Signed-off-by: jayzhan211 <[email protected]>

* fix field order

Signed-off-by: jayzhan211 <[email protected]>

* add list of struct test

Signed-off-by: jayzhan211 <[email protected]>

* upd err msg

Signed-off-by: jayzhan211 <[email protected]>

* fmt

Signed-off-by: jayzhan211 <[email protected]>

---------

Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 authored Oct 11, 2024
1 parent 58c32cb commit 3b6aac2
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 14 deletions.
46 changes: 44 additions & 2 deletions datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ use crate::operator::Operator;
use arrow::array::{new_empty_array, Array};
use arrow::compute::can_cast_types;
use arrow::datatypes::{
DataType, Field, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
DataType, Field, FieldRef, Fields, TimeUnit, DECIMAL128_MAX_PRECISION,
DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
};
use datafusion_common::{exec_datafusion_err, plan_datafusion_err, plan_err, Result};

Expand Down Expand Up @@ -370,6 +370,8 @@ impl From<&DataType> for TypeCategory {
/// align with the behavior of Postgres. Therefore, we've made slight adjustments to the rules
/// to better match the behavior of both Postgres and DuckDB. For example, we expect adjusted
/// decimal precision and scale when coercing decimal types.
///
/// This function doesn't preserve correct field name and nullability for the struct type, we only care about data type.
pub fn type_union_resolution(data_types: &[DataType]) -> Option<DataType> {
if data_types.is_empty() {
return None;
Expand Down Expand Up @@ -476,6 +478,46 @@ fn type_union_resolution_coercion(
type_union_resolution_coercion(lhs.data_type(), rhs.data_type());
new_item_type.map(|t| DataType::List(Arc::new(Field::new("item", t, true))))
}
(DataType::Struct(lhs), DataType::Struct(rhs)) => {
if lhs.len() != rhs.len() {
return None;
}

// Search the right-hand side for the field with the SAME name as
// `lhs_field` and, if found, return the union-resolved coercion of the
// two field types. Returns `None` either when no field of that name
// exists on the right-hand side or when the two field types cannot be
// coerced — in both cases the enclosing struct coercion fails.
fn search_corresponding_coerced_type(
    lhs_field: &FieldRef,
    rhs: &Fields,
) -> Option<DataType> {
    rhs.iter()
        .find(|rhs_field| lhs_field.name() == rhs_field.name())
        .and_then(|rhs_field| {
            type_union_resolution_coercion(
                lhs_field.data_type(),
                rhs_field.data_type(),
            )
        })
}

let types = lhs
.iter()
.map(|lhs_field| search_corresponding_coerced_type(lhs_field, rhs))
.collect::<Option<Vec<_>>>()?;

let fields = types
.into_iter()
.enumerate()
.map(|(i, datatype)| {
Arc::new(Field::new(format!("c{i}"), datatype, true))
})
.collect::<Vec<FieldRef>>();
Some(DataType::Struct(fields.into()))
}
_ => {
// numeric coercion is the same as comparison coercion, both find the narrowest type
// that can accommodate both types
Expand Down
39 changes: 28 additions & 11 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,20 +221,37 @@ fn get_valid_types_with_scalar_udf(
current_types: &[DataType],
func: &ScalarUDF,
) -> Result<Vec<Vec<DataType>>> {
let valid_types = match signature {
match signature {
TypeSignature::UserDefined => match func.coerce_types(current_types) {
Ok(coerced_types) => vec![coerced_types],
Err(e) => return exec_err!("User-defined coercion failed with {:?}", e),
Ok(coerced_types) => Ok(vec![coerced_types]),
Err(e) => exec_err!("User-defined coercion failed with {:?}", e),
},
TypeSignature::OneOf(signatures) => signatures
.iter()
.filter_map(|t| get_valid_types_with_scalar_udf(t, current_types, func).ok())
.flatten()
.collect::<Vec<_>>(),
_ => get_valid_types(signature, current_types)?,
};
TypeSignature::OneOf(signatures) => {
let mut res = vec![];
let mut errors = vec![];
for sig in signatures {
match get_valid_types_with_scalar_udf(sig, current_types, func) {
Ok(valid_types) => {
res.extend(valid_types);
}
Err(e) => {
errors.push(e.to_string());
}
}
}

Ok(valid_types)
// Every signature failed, return the joined error
if res.is_empty() {
internal_err!(
"Failed to match any signature, errors: {}",
errors.join(",")
)
} else {
Ok(res)
}
}
_ => get_valid_types(signature, current_types),
}
}

fn get_valid_types_with_aggregate_udf(
Expand Down
48 changes: 48 additions & 0 deletions datafusion/functions-nested/src/make_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ use arrow_array::{
use arrow_buffer::OffsetBuffer;
use arrow_schema::DataType::{LargeList, List, Null};
use arrow_schema::{DataType, Field};
use datafusion_common::{exec_err, internal_err};
use datafusion_common::{plan_err, utils::array_into_list_array_nullable, Result};
use datafusion_expr::binary::type_union_resolution;
use datafusion_expr::TypeSignature;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use itertools::Itertools;

use crate::utils::make_scalar_function;

Expand Down Expand Up @@ -106,6 +108,32 @@ impl ScalarUDFImpl for MakeArray {

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if let Some(new_type) = type_union_resolution(arg_types) {
// TODO: Move the logic to type_union_resolution if this applies to other functions as well
// Handle struct where we only change the data type but preserve the field name and nullability.
// Since field name is the key of the struct, so it shouldn't be updated to the common column name like "c0" or "c1"
let is_struct_and_has_same_key = are_all_struct_and_have_same_key(arg_types)?;
if is_struct_and_has_same_key {
let data_types: Vec<_> = if let DataType::Struct(fields) = &arg_types[0] {
fields.iter().map(|f| f.data_type().to_owned()).collect()
} else {
return internal_err!("Struct type is checked is the previous function, so this should be unreachable");
};

let mut final_struct_types = vec![];
for s in arg_types {
let mut new_fields = vec![];
if let DataType::Struct(fields) = s {
for (i, f) in fields.iter().enumerate() {
let field = Arc::unwrap_or_clone(Arc::clone(f))
.with_data_type(data_types[i].to_owned());
new_fields.push(Arc::new(field));
}
}
final_struct_types.push(DataType::Struct(new_fields.into()))
}
return Ok(final_struct_types);
}

if let DataType::FixedSizeList(field, _) = new_type {
Ok(vec![DataType::List(field); arg_types.len()])
} else if new_type.is_null() {
Expand All @@ -123,6 +151,26 @@ impl ScalarUDFImpl for MakeArray {
}
}

/// Returns `Ok(true)` when every entry in `data_types` is a struct and all
/// of the structs share the exact same sequence of field names, and
/// `Ok(false)` as soon as any entry is not a struct. Structs whose field
/// names differ produce an execution error rather than `Ok(false)`.
fn are_all_struct_and_have_same_key(data_types: &[DataType]) -> Result<bool> {
    let mut expected_keys: Option<String> = None;
    for data_type in data_types {
        let DataType::Struct(fields) = data_type else {
            // Any non-struct argument disables the struct handling path.
            return Ok(false);
        };
        // The comma-joined field names, in declaration order, act as the
        // struct's "key" for comparison purposes.
        let keys = fields.iter().map(|f| f.name().as_str()).join(",");
        match expected_keys {
            None => expected_keys = Some(keys),
            Some(ref expected) if *expected != keys => {
                return exec_err!("Expect same keys for struct type but got mismatched pair {} and {}", expected, keys);
            }
            Some(_) => {}
        }
    }

    Ok(true)
}

// Empty array is a special case that is useful for many other array functions
pub(super) fn empty_array_type() -> DataType {
DataType::List(Arc::new(Field::new("item", DataType::Int64, true)))
Expand Down
109 changes: 108 additions & 1 deletion datafusion/sqllogictest/test_files/struct.slt
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,34 @@ You reached the bottom!
statement ok
drop view complex_view;

# struct with different keys r1 and r2 is not valid
statement ok
create table t(a struct<r1 varchar, c int>, b struct<r2 varchar, c float>) as values (struct('red', 1), struct('blue', 2.3));

# Expect same keys for struct type but got mismatched pair r1,c and r2,c
query error
select [a, b] from t;

statement ok
drop table t;

# struct with the same key
statement ok
create table t(a struct<r varchar, c int>, b struct<r varchar, c float>) as values (struct('red', 1), struct('blue', 2.3));

query T
select arrow_typeof([a, b]) from t;
----
List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })

query ?
select [a, b] from t;
----
[{r: red, c: 1}, {r: blue, c: 2}]

statement ok
drop table t;

# Test row alias

query ?
Expand Down Expand Up @@ -412,7 +440,6 @@ select * from t;
----
{r: red, b: 2} {r: blue, b: 2.3}

# TODO: Should be coerced to float
query T
select arrow_typeof(c1) from t;
----
Expand All @@ -422,3 +449,83 @@ query T
select arrow_typeof(c2) from t;
----
Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])

statement ok
drop table t;

##################################
## Test Coalesce with Struct
##################################

statement ok
CREATE TABLE t (
s1 struct(a int, b varchar),
s2 struct(a float, b varchar)
) AS VALUES
(row(1, 'red'), row(1.1, 'string1')),
(row(2, 'blue'), row(2.2, 'string2')),
(row(3, 'green'), row(33.2, 'string3'))
;

query ?
select coalesce(s1) from t;
----
{a: 1, b: red}
{a: 2, b: blue}
{a: 3, b: green}

# TODO: a's type should be float
query T
select arrow_typeof(coalesce(s1)) from t;
----
Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])
Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])
Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])

statement ok
drop table t;

statement ok
CREATE TABLE t (
s1 struct(a int, b varchar),
s2 struct(a float, b varchar)
) AS VALUES
(row(1, 'red'), row(1.1, 'string1')),
(null, row(2.2, 'string2')),
(row(3, 'green'), row(33.2, 'string3'))
;

# TODO: second column should not be null
query ?
select coalesce(s1) from t;
----
{a: 1, b: red}
NULL
{a: 3, b: green}

statement ok
drop table t;

# row() with incorrect order
statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'blue' to value of Float64 type
create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values
(row('red', 1), row(2.3, 'blue')),
(row('purple', 1), row('green', 2.3));

# out of order struct literal
# TODO: This query should not fail
statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type
create table t(a struct(r varchar, c int)) as values ({r: 'a', c: 1}), ({c: 2, r: 'b'});

##################################
## Test Array of Struct
##################################

query ?
select [{r: 'a', c: 1}, {r: 'b', c: 2}];
----
[{r: a, c: 1}, {r: b, c: 2}]

# Can't create a list of struct with different field types
query error
select [{r: 'a', c: 1}, {c: 2, r: 'b'}];

0 comments on commit 3b6aac2

Please sign in to comment.