Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed Aug 19, 2023
1 parent f1ef412 commit 77b014c
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 54 deletions.
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_plan/joins/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1958,7 +1958,7 @@ mod tests {

let result = get_updated_right_ordering_equivalence_properties(
&join_type,
&[right_oeq_classes.clone()],
&[right_oeq_classes],
left_columns_len,
&join_eq_properties,
)?;
Expand Down
143 changes: 96 additions & 47 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
use std::sync::Arc;

use arrow::datatypes::{DataType, IntervalUnit};
use arrow::datatypes::{DataType, Field, IntervalUnit};

use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter};
Expand Down Expand Up @@ -543,6 +543,90 @@ fn coerce_arguments_for_signature(
.collect::<Result<Vec<_>>>()
}

// TODO: Add this function to arrow-rs
/// Returns the innermost non-`List` element type of a (possibly nested)
/// `DataType::List`.
///
/// Nested lists are unwrapped level by level until an element type that is
/// not itself a `List` is reached; that type is returned as an owned clone.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` when `data_type` is not a
/// `DataType::List` at the top level.
fn get_list_base_type(data_type: &DataType) -> Result<DataType> {
    // The outermost type must be a list; anything else is rejected up front.
    let mut element_type = match data_type {
        DataType::List(field) => field.data_type(),
        _ => {
            return Err(DataFusionError::Internal(
                "Only List type is supported".to_string(),
            ))
        }
    };

    // Peel off nested list layers until a non-list element type remains.
    while let DataType::List(inner) = element_type {
        element_type = inner.data_type();
    }

    Ok(element_type.clone())
}

fn coerce_nulls_for_array_append(
expressions: Vec<Expr>,
schema: &DFSchema,
) -> Result<Vec<Expr>> {
assert_eq!(expressions.len(), 2);

let data_types: Result<Vec<_>> =
expressions.iter().map(|e| e.get_type(schema)).collect();
let data_types = data_types?;

if data_types[1] == DataType::Null {
let to_type = get_list_base_type(&data_types[0])?;
let arg1 = lit(ScalarValue::try_from(to_type)?);
return Ok(vec![expressions[0].clone(), arg1]);
}

if let DataType::List(ref field) = data_types[0] {
if field.data_type() == &DataType::Null {
let arg0 = cast_array_expr(
&expressions[0],
&data_types[0],
&DataType::List(Arc::new(Field::new(
field.name(),
data_types[1].clone(),
field.is_nullable(),
))),
schema,
)?;
return Ok(vec![arg0, expressions[1].clone()]);
}
}

Ok(expressions)
}

fn coerce_nulls_for_array_prepend(
expressions: Vec<Expr>,
schema: &DFSchema,
) -> Result<Vec<Expr>> {
assert_eq!(expressions.len(), 2);

let data_types: Result<Vec<_>> =
expressions.iter().map(|e| e.get_type(schema)).collect();
let data_types = data_types?;

if data_types[0] == DataType::Null {
let to_type = get_list_base_type(&data_types[1])?;
let arg0 = lit(ScalarValue::try_from(to_type)?);
return Ok(vec![arg0, expressions[1].clone()]);
}

if let DataType::List(ref field) = data_types[1] {
if field.data_type() == &DataType::Null {
let arg1 = cast_array_expr(
&expressions[1],
&data_types[1],
&DataType::List(Arc::new(Field::new(
field.name(),
data_types[0].clone(),
field.is_nullable(),
))),
schema,
)?;
return Ok(vec![expressions[0].clone(), arg1]);
}
}

Ok(expressions)
}

fn coerce_arguments_for_fun(
expressions: &[Expr],
schema: &DFSchema,
Expand Down Expand Up @@ -584,59 +668,24 @@ fn coerce_arguments_for_fun(
.fold(current_types.first().unwrap().clone(), |acc, x| {
comparison_coercion(&acc, x).unwrap_or(acc)
});

return expressions
.iter()
.zip(current_types)
.map(|(expr, from_type)| cast_array_expr(expr, &from_type, &new_type, schema))
.collect();
.collect::<Result<Vec<_>>>();
}

// Represent NULL as element
// Iterate once to get non-null type
// Convert null type to non-null type with None
// i.e. ScalarValue::Int64(None)

let data_types = expressions
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;

let mut found_null = false;
// Assume that all the non-null types are the same
let mut first_non_null: Option<DataType> = None;

for data_type in data_types.iter() {
if *data_type == DataType::Null {
found_null = true;
} else if first_non_null.is_none() {
first_non_null = Some(data_type.clone());
// Convert Null to ScalarValue
// If data_type is Int64, we will convert it to ScalarValue::Int64(None)
match fun {
BuiltinScalarFunction::ArrayAppend => {
coerce_nulls_for_array_append(expressions, schema)
}
}

if found_null {
let mut expressions = expressions;
match first_non_null {
Some(DataType::List(field)) => {
let arr_val_type = field.data_type().clone();
for expr in expressions.iter_mut() {
if expr.get_type(schema)? == DataType::Null {
*expr = lit(ScalarValue::try_from(arr_val_type.clone())?);
}
}
}
Some(data_type) => {
for expr in expressions.iter_mut() {
if expr.get_type(schema)? == DataType::Null {
*expr = lit(ScalarValue::try_from(data_type.clone())?);
}
}
}
None => {}
BuiltinScalarFunction::ArrayPrepend => {
coerce_nulls_for_array_prepend(expressions, schema)
}
Ok(expressions)
} else {
Ok(expressions)

_ => Ok(expressions),
}
}

Expand All @@ -653,7 +702,7 @@ fn cast_array_expr(
schema: &DFSchema,
) -> Result<Expr> {
if from_type.equals_datatype(&DataType::Null) {
Ok(expr.clone())
ScalarValue::try_from(to_type.clone()).map(lit)
} else {
cast_expr(expr, to_type, schema)
}
Expand Down
34 changes: 28 additions & 6 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,7 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), array_slice(make_array('h'
----
[1] [h, e]

# array_slice scalar function #13 (with negative number and NULL)
# array_slice scalar function #13 (with positive number and NULL)
query error
select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL);

Expand Down Expand Up @@ -864,13 +864,23 @@ select make_array(['a','b'], null);

## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`)

# TODO: array_append with NULLs
# array_append scalar function #1
# query ?
# select array_append(make_array(), 4);
# ----
# [4]
query ?
select array_append(make_array(null), 4);
----
[, 4]

query ?
select array_append(make_array(1, 2, null), 4);
----
[1, 2, , 4]

query ?
select array_append(make_array(), 4);
----
[4]

# TODO: array_append with NULLs
# array_append scalar function #2
# query ??
# select array_append(make_array(), make_array()), array_append(make_array(), make_array(4));
Expand Down Expand Up @@ -965,6 +975,18 @@ select array_append(column1, make_array(1, 11, 111)), array_append(make_array(ma
# ----
# [4]

query ?
select array_prepend(4, make_array());
----
[4]

query ?
select array_prepend(4, make_array(null));
----
[4, ]



# array_prepend scalar function #2
# query ??
# select array_prepend(make_array(), make_array()), array_prepend(make_array(4), make_array());
Expand Down

0 comments on commit 77b014c

Please sign in to comment.