Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support nulls and empty for array functions #7338

Closed
wants to merge 14 commits into from
Prev Previous commit
Next Next commit
cleanup
Signed-off-by: jayzhan211 <[email protected]>
jayzhan211 committed Aug 25, 2023
commit ad84d014a8717f9b72da5580620fbd63c6b9176b
143 changes: 96 additions & 47 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@
use std::sync::Arc;

use arrow::datatypes::{DataType, IntervalUnit};
use arrow::datatypes::{DataType, Field, IntervalUnit};

use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter};
@@ -553,6 +553,90 @@ fn coerce_arguments_for_signature(
.collect::<Result<Vec<_>>>()
}

// TODO: Add this function to arrow-rs
/// Returns the innermost non-`List` element type of a (possibly nested)
/// `DataType::List`.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` when `data_type` is not a
/// `DataType::List`.
fn get_list_base_type(data_type: &DataType) -> Result<DataType> {
    // The outermost type must be a List; anything else is rejected up front.
    let mut element_type = match data_type {
        DataType::List(field) => field.data_type(),
        _ => {
            return Err(DataFusionError::Internal(
                "Only List type is supported".to_string(),
            ))
        }
    };

    // Peel nested List layers until a non-List element type is reached.
    while let DataType::List(inner) = element_type {
        element_type = inner.data_type();
    }

    Ok(element_type.clone())
}

/// Rewrites the two arguments used for `BuiltinScalarFunction::ArrayAppend`
/// so that a `Null`-typed operand is given a concrete type.
///
/// The first expression is treated as the list and the second as the value
/// being appended:
///
/// * If the second argument has type `DataType::Null`, it is replaced by a
///   literal built with `ScalarValue::try_from` on the base type of the first
///   argument (presumably a typed null such as `ScalarValue::Int64(None)`;
///   confirm against `ScalarValue`).
/// * Otherwise, if the first argument is a `List` whose element type is
///   `DataType::Null`, it is passed to `cast_array_expr` with a `List` of the
///   second argument's type as the target, keeping the field name and
///   nullability.
/// * Otherwise the expressions are returned unchanged.
///
/// # Panics
///
/// Panics if `expressions` does not contain exactly two elements.
///
/// # Errors
///
/// Propagates errors from `Expr::get_type`, `get_list_base_type`,
/// `ScalarValue::try_from` and `cast_array_expr`.
fn coerce_nulls_for_array_append(
jayzhan211 marked this conversation as resolved.
Show resolved Hide resolved
expressions: Vec<Expr>,
schema: &DFSchema,
) -> Result<Vec<Expr>> {
assert_eq!(expressions.len(), 2);

// Resolve both argument types first; the first failure is propagated.
let data_types: Result<Vec<_>> =
expressions.iter().map(|e| e.get_type(schema)).collect();
let data_types = data_types?;

// Appended value is NULL: build the replacement literal from the list's base type.
if data_types[1] == DataType::Null {
let to_type = get_list_base_type(&data_types[0])?;
let arg1 = lit(ScalarValue::try_from(to_type)?);
return Ok(vec![expressions[0].clone(), arg1]);
}

// List has Null element type: retarget it to a List of the appended value's type.
if let DataType::List(ref field) = data_types[0] {
if field.data_type() == &DataType::Null {
let arg0 = cast_array_expr(
&expressions[0],
&data_types[0],
&DataType::List(Arc::new(Field::new(
field.name(),
data_types[1].clone(),
field.is_nullable(),
))),
schema,
)?;
return Ok(vec![arg0, expressions[1].clone()]);
}
}

// Neither special case applies.
Ok(expressions)
}

/// Rewrites the two arguments used for `BuiltinScalarFunction::ArrayPrepend`
/// so that a `Null`-typed operand is given a concrete type.
///
/// The first expression is treated as the value being prepended and the
/// second as the list:
///
/// * If the first argument has type `DataType::Null`, it is replaced by a
///   literal built with `ScalarValue::try_from` on the base type of the list.
/// * Otherwise, if the list's element type is `DataType::Null`, the list is
///   passed to `cast_array_expr` with a `List` of the first argument's type as
///   the target, keeping the field name and nullability.
/// * Otherwise the expressions are returned unchanged.
///
/// # Panics
///
/// Panics if `expressions` does not contain exactly two elements.
///
/// # Errors
///
/// Propagates errors from `Expr::get_type`, `get_list_base_type`,
/// `ScalarValue::try_from` and `cast_array_expr`.
fn coerce_nulls_for_array_prepend(
    expressions: Vec<Expr>,
    schema: &DFSchema,
) -> Result<Vec<Expr>> {
    assert_eq!(expressions.len(), 2);

    // Resolve the argument types in order so the first failure is the one reported.
    let element_type = expressions[0].get_type(schema)?;
    let list_type = expressions[1].get_type(schema)?;

    // Prepended value is NULL: build the replacement literal from the list's base type.
    if element_type == DataType::Null {
        let base_type = get_list_base_type(&list_type)?;
        let typed_element = lit(ScalarValue::try_from(base_type)?);
        return Ok(vec![typed_element, expressions[1].clone()]);
    }

    match &list_type {
        // List has Null element type: retarget it to a List of the prepended value's type.
        DataType::List(field) if field.data_type() == &DataType::Null => {
            let target_type = DataType::List(Arc::new(Field::new(
                field.name(),
                element_type.clone(),
                field.is_nullable(),
            )));
            let typed_list =
                cast_array_expr(&expressions[1], &list_type, &target_type, schema)?;
            Ok(vec![expressions[0].clone(), typed_list])
        }
        // Neither special case applies.
        _ => Ok(expressions),
    }
}

fn coerce_arguments_for_fun(
expressions: &[Expr],
schema: &DFSchema,
@@ -594,59 +678,24 @@ fn coerce_arguments_for_fun(
.fold(current_types.first().unwrap().clone(), |acc, x| {
comparison_coercion(&acc, x).unwrap_or(acc)
});

return expressions
.iter()
.zip(current_types)
.map(|(expr, from_type)| cast_array_expr(expr, &from_type, &new_type, schema))
.collect();
.collect::<Result<Vec<_>>>();
}

// Represent NULL as element
// Iterate once to get non-null type
// Convert null type to non-null type with None
// i.e. ScalarValue::Int64(None)

let data_types = expressions
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;

let mut found_null = false;
// Assume that all the non-null types are the same
let mut first_non_null: Option<DataType> = None;

for data_type in data_types.iter() {
if *data_type == DataType::Null {
found_null = true;
} else if first_non_null.is_none() {
first_non_null = Some(data_type.clone());
// Convert Null to ScalarValue
// If data_type is Int64, we will convert it to ScalarValue::Int64(None)
match fun {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am inclined to solve this problem by extending the signature's structure, because the current approach runs into a difficulty with user-defined functions. For example, suppose I am an Arrow DataFusion user and I want to define my own ArrayAppend implementation (a function called new_array_append). How would that function handle nulls?

What do you think about it, @alamb and @jayzhan211?

BuiltinScalarFunction::ArrayAppend => {
coerce_nulls_for_array_append(expressions, schema)
}
}

if found_null {
let mut expressions = expressions;
match first_non_null {
Some(DataType::List(field)) => {
let arr_val_type = field.data_type().clone();
for expr in expressions.iter_mut() {
if expr.get_type(schema)? == DataType::Null {
*expr = lit(ScalarValue::try_from(arr_val_type.clone())?);
}
}
}
Some(data_type) => {
for expr in expressions.iter_mut() {
if expr.get_type(schema)? == DataType::Null {
*expr = lit(ScalarValue::try_from(data_type.clone())?);
}
}
}
None => {}
BuiltinScalarFunction::ArrayPrepend => {
coerce_nulls_for_array_prepend(expressions, schema)
}
Ok(expressions)
} else {
Ok(expressions)

_ => Ok(expressions),
}
}

@@ -663,7 +712,7 @@ fn cast_array_expr(
schema: &DFSchema,
) -> Result<Expr> {
if from_type.equals_datatype(&DataType::Null) {
Ok(expr.clone())
ScalarValue::try_from(to_type.clone()).map(lit)
} else {
cast_expr(expr, to_type, schema)
}
34 changes: 28 additions & 6 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
@@ -827,7 +827,7 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), array_slice(make_array('h'
----
[1] [h, e]

# array_slice scalar function #13 (with negative number and NULL)
# array_slice scalar function #13 (with positive number and NULL)
query error
select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL);

@@ -941,13 +941,23 @@ select make_array(['a','b'], null);

## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`)

# TODO: array_append with NULLs
# array_append scalar function #1
# query ?
# select array_append(make_array(), 4);
# ----
# [4]
query ?
select array_append(make_array(null), 4);
----
[, 4]

query ?
select array_append(make_array(1, 2, null), 4);
----
[1, 2, , 4]

query ?
select array_append(make_array(), 4);
----
[4]

# TODO: array_append with NULLs
# array_append scalar function #2
# query ??
# select array_append(make_array(), make_array()), array_append(make_array(), make_array(4));
@@ -1042,6 +1052,18 @@ select array_append(column1, make_array(1, 11, 111)), array_append(make_array(ma
# ----
# [4]

query ?
select array_prepend(4, make_array());
----
[4]

query ?
select array_prepend(4, make_array(null));
----
[4, ]



# array_prepend scalar function #2
# query ??
# select array_prepend(make_array(), make_array()), array_prepend(make_array(4), make_array());