From 77b014ca781eaa788bc5604f9808885ec93ed346 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 19 Aug 2023 14:23:20 +0800 Subject: [PATCH] cleanup Signed-off-by: jayzhan211 --- .../core/src/physical_plan/joins/utils.rs | 2 +- .../optimizer/src/analyzer/type_coercion.rs | 143 ++++++++++++------ datafusion/sqllogictest/test_files/array.slt | 34 ++++- 3 files changed, 125 insertions(+), 54 deletions(-) diff --git a/datafusion/core/src/physical_plan/joins/utils.rs b/datafusion/core/src/physical_plan/joins/utils.rs index dd23e39e2774f..0b4a30da30f5c 100644 --- a/datafusion/core/src/physical_plan/joins/utils.rs +++ b/datafusion/core/src/physical_plan/joins/utils.rs @@ -1958,7 +1958,7 @@ mod tests { let result = get_updated_right_ordering_equivalence_properties( &join_type, - &[right_oeq_classes.clone()], + &[right_oeq_classes], left_columns_len, &join_eq_properties, )?; diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 6d2a7ea2d983e..8ffc02520cf2e 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -19,7 +19,7 @@ use std::sync::Arc; -use arrow::datatypes::{DataType, IntervalUnit}; +use arrow::datatypes::{DataType, Field, IntervalUnit}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; @@ -543,6 +543,90 @@ fn coerce_arguments_for_signature( .collect::>>() } +// TODO: Add this function to arrow-rs +fn get_list_base_type(data_type: &DataType) -> Result { + match data_type { + DataType::List(field) => match field.data_type() { + DataType::List(_) => get_list_base_type(field.data_type()), + base_type => Ok(base_type.clone()), + }, + + _ => Err(DataFusionError::Internal( + "Only List type is supported".to_string(), + )), + } +} + +fn coerce_nulls_for_array_append( + expressions: Vec, + schema: &DFSchema, +) -> Result> { + assert_eq!(expressions.len(), 2); + + let data_types: Result> = + expressions.iter().map(|e| e.get_type(schema)).collect(); + let data_types = data_types?; + + if data_types[1] == DataType::Null { + let to_type = get_list_base_type(&data_types[0])?; + let arg1 = lit(ScalarValue::try_from(to_type)?); + return Ok(vec![expressions[0].clone(), arg1]); + } + + if let DataType::List(ref field) = data_types[0] { + if field.data_type() == &DataType::Null { + let arg0 = cast_array_expr( + &expressions[0], + &data_types[0], + &DataType::List(Arc::new(Field::new( + field.name(), + data_types[1].clone(), + field.is_nullable(), + ))), + schema, + )?; + return Ok(vec![arg0, expressions[1].clone()]); + } + } + + Ok(expressions) +} + +fn coerce_nulls_for_array_prepend( + expressions: Vec, + schema: &DFSchema, +) -> Result> { + assert_eq!(expressions.len(), 2); + + let data_types: Result> = + expressions.iter().map(|e| e.get_type(schema)).collect(); + let data_types = data_types?; + + if data_types[0] == DataType::Null { + let to_type = get_list_base_type(&data_types[1])?; + let arg0 = lit(ScalarValue::try_from(to_type)?); + return Ok(vec![arg0, expressions[1].clone()]); + } + + if let DataType::List(ref field) = data_types[1] { + if field.data_type() == &DataType::Null { + let arg1 = cast_array_expr( + &expressions[1], + &data_types[1], + &DataType::List(Arc::new(Field::new( + field.name(), + data_types[0].clone(), + field.is_nullable(), + ))), + schema, + )?; + return Ok(vec![expressions[0].clone(), arg1]); + } + } + + Ok(expressions) +} + fn coerce_arguments_for_fun( expressions: &[Expr], schema: &DFSchema, @@ -584,59 +668,24 @@ fn coerce_arguments_for_fun( .fold(current_types.first().unwrap().clone(), |acc, x| { comparison_coercion(&acc, x).unwrap_or(acc) }); - return expressions .iter() .zip(current_types) .map(|(expr, from_type)| cast_array_expr(expr, &from_type, &new_type, schema)) - .collect(); + .collect::>>(); } - // Represent NULL as element - // Iterate once to get non-null type - // Convert null type to non-null type with None - // i.e. ScalarValue::Int64(None) - - let data_types = expressions - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - - let mut found_null = false; - // Assume that all the non-null types are the same - let mut first_non_null: Option = None; - - for data_type in data_types.iter() { - if *data_type == DataType::Null { - found_null = true; - } else if first_non_null.is_none() { - first_non_null = Some(data_type.clone()); + // Convert Null to ScalarValue + // If data_type is Int64, we will convert it to ScalarValue::Int64(None) + match fun { + BuiltinScalarFunction::ArrayAppend => { + coerce_nulls_for_array_append(expressions, schema) } - } - - if found_null { - let mut expressions = expressions; - match first_non_null { - Some(DataType::List(field)) => { - let arr_val_type = field.data_type().clone(); - for expr in expressions.iter_mut() { - if expr.get_type(schema)? == DataType::Null { - *expr = lit(ScalarValue::try_from(arr_val_type.clone())?); - } - } - } - Some(data_type) => { - for expr in expressions.iter_mut() { - if expr.get_type(schema)? == DataType::Null { - *expr = lit(ScalarValue::try_from(data_type.clone())?); - } - } - } - None => {} + BuiltinScalarFunction::ArrayPrepend => { + coerce_nulls_for_array_prepend(expressions, schema) } - Ok(expressions) - } else { - Ok(expressions) + + _ => Ok(expressions), } } @@ -653,7 +702,7 @@ fn cast_array_expr( schema: &DFSchema, ) -> Result { if from_type.equals_datatype(&DataType::Null) { - Ok(expr.clone()) + ScalarValue::try_from(to_type.clone()).map(lit) } else { cast_expr(expr, to_type, schema) } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 98ffa6b18b566..e1bbecde3027e 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -750,7 +750,7 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), array_slice(make_array('h' ---- [1] [h, e] -# array_slice scalar function #13 (with negative number and NULL) +# array_slice scalar function #13 (with positive number and NULL) query error select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); @@ -864,13 +864,23 @@ select make_array(['a','b'], null); ## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`) -# TODO: array_append with NULLs # array_append scalar function #1 -# query ? -# select array_append(make_array(), 4); -# ---- -# [4] +query ? +select array_append(make_array(null), 4); +---- +[, 4] +query ? +select array_append(make_array(1, 2, null), 4); +---- +[1, 2, , 4] + +query ? +select array_append(make_array(), 4); +---- +[4] + +# TODO: array_append with NULLs # array_append scalar function #2 # query ?? # select array_append(make_array(), make_array()), array_append(make_array(), make_array(4)); @@ -965,6 +975,18 @@ select array_append(column1, make_array(1, 11, 111)), array_append(make_array(ma # ---- # [4] +query ? +select array_prepend(4, make_array()); +---- +[4] + +query ? +select array_prepend(4, make_array(null)); +---- +[4, ] + + + # array_prepend scalar function #2 # query ?? # select array_prepend(make_array(), make_array()), array_prepend(make_array(4), make_array());