Skip to content

Commit

Permalink
put all type coercion in coerce_arguments_for_signature
Browse files Browse the repository at this point in the history
  • Loading branch information
Weijun-H committed Jan 30, 2024
1 parent 76cede4 commit b1d79ba
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 56 deletions.
87 changes: 39 additions & 48 deletions datafusion/common/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -464,62 +464,53 @@ pub fn coerced_type_with_base_type_only(
base_type: &DataType,
) -> DataType {
match data_type {
DataType::List(field)
| DataType::FixedSizeList(field, _)
| DataType::LargeList(field) => {
let field_type = match field.data_type() {
// nested type could be different list type
DataType::List(_)
| DataType::FixedSizeList(_, _)
| DataType::LargeList(_) => {
coerced_type_with_base_type_only(field.data_type(), base_type)
}
_ => base_type.to_owned(),
};
if matches!(data_type, DataType::LargeList(_)) {
DataType::LargeList(Arc::new(Field::new(
field.name(),
field_type,
field.is_nullable(),
)))
} else {
DataType::List(Arc::new(Field::new(
field.name(),
field_type,
field.is_nullable(),
)))
}
DataType::List(field) | DataType::FixedSizeList(field, _) => {
let field_type =
coerced_type_with_base_type_only(field.data_type(), base_type);

DataType::List(Arc::new(Field::new(
field.name(),
field_type,
field.is_nullable(),
)))
}
DataType::LargeList(field) => {
let field_type =
coerced_type_with_base_type_only(field.data_type(), base_type);

DataType::LargeList(Arc::new(Field::new(
field.name(),
field_type,
field.is_nullable(),
)))
}

_ => base_type.clone(),
}
}

pub fn coerced_fixed_size_list_to_list(data_type: &DataType) -> DataType {
match data_type {
DataType::FixedSizeList(field, _) => {
let field_type = match field.data_type() {
DataType::List(_)
| DataType::FixedSizeList(_, _)
| DataType::LargeList(_) => {
coerced_fixed_size_list_to_list(field.data_type())
}
_ => field.data_type().to_owned(),
};
if matches!(data_type, DataType::LargeList(_)) {
DataType::LargeList(Arc::new(Field::new(
field.name(),
field_type,
field.is_nullable(),
)))
} else {
DataType::List(Arc::new(Field::new(
field.name(),
field_type,
field.is_nullable(),
)))
}
DataType::List(field) | DataType::FixedSizeList(field, _) => {
let field_type = coerced_fixed_size_list_to_list(field.data_type());

DataType::List(Arc::new(Field::new(
field.name(),
field_type,
field.is_nullable(),
)))
}
_ => data_type.to_owned(),
DataType::LargeList(field) => {
let field_type = coerced_fixed_size_list_to_list(field.data_type());

DataType::LargeList(Arc::new(Field::new(
field.name(),
field_type,
field.is_nullable(),
)))
}

_ => data_type.clone(),
}
}

Expand Down
5 changes: 3 additions & 2 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use arrow::{
compute::can_cast_types,
datatypes::{DataType, TimeUnit},
};
use datafusion_common::utils::list_ndims;
use datafusion_common::utils::{coerced_fixed_size_list_to_list, list_ndims};
use datafusion_common::{
internal_datafusion_err, internal_err, plan_err, DataFusionError, Result,
};
Expand Down Expand Up @@ -141,7 +141,8 @@ fn get_valid_types(
DataType::List(_)
| DataType::LargeList(_)
| DataType::FixedSizeList(_, _) => {
Ok(vec![vec![array_type.clone(), DataType::Int64]])
let array_type = coerced_fixed_size_list_to_list(array_type);
Ok(vec![vec![array_type, DataType::Int64]])
}
_ => Ok(vec![vec![]]),
}
Expand Down
15 changes: 9 additions & 6 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ use arrow::datatypes::{DataType, IntervalUnit};

use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter};
use datafusion_common::utils::coerced_fixed_size_list_to_list;
use datafusion_common::{
exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef,
DataFusionError, Result, ScalarValue,
Expand Down Expand Up @@ -590,17 +589,21 @@ fn coerce_arguments_for_fun(
if expressions.is_empty() {
return Ok(vec![]);
}

let mut expressions: Vec<Expr> = expressions.to_vec();

// coerce the fixed size list to list for all array fucntions
if fun.name().contains("array") {
// Cast Fixedsizelist to List for array functions
if *fun == BuiltinScalarFunction::MakeArray {
expressions = expressions
.into_iter()
.map(|expr| {
let data_type = expr.get_type(schema).unwrap();
let to_type = coerced_fixed_size_list_to_list(&data_type);
expr.cast_to(&to_type, schema)
if let DataType::FixedSizeList(field, _) = data_type {
let field = field.as_ref().clone();
let to_type = DataType::List(Arc::new(field));
expr.cast_to(&to_type, schema)
} else {
Ok(expr)
}
})
.collect::<Result<Vec<_>>>()?;
}
Expand Down

0 comments on commit b1d79ba

Please sign in to comment.