From d3f63728d222cc5cf30cf03a12ec9a0b41399b18 Mon Sep 17 00:00:00 2001
From: Jay Zhan
Date: Thu, 11 Jul 2024 07:32:03 +0800
Subject: [PATCH 01/14] Change `array_agg` to return `null` on no input rather
 than empty list (#11299)

* change array agg semantic for empty result
Signed-off-by: jayzhan211
* return null
Signed-off-by: jayzhan211
* fix test
Signed-off-by: jayzhan211
* fix order sensitive
Signed-off-by: jayzhan211
* fix test
Signed-off-by: jayzhan211
* add more test
Signed-off-by: jayzhan211
* fix null
Signed-off-by: jayzhan211
* fix multi-phase case
Signed-off-by: jayzhan211
* add comment
Signed-off-by: jayzhan211
* cleanup
Signed-off-by: jayzhan211
* fix clone
Signed-off-by: jayzhan211

---------

Signed-off-by: jayzhan211
---
 datafusion/common/src/scalar/mod.rs           |  10 ++
 datafusion/core/tests/dataframe/mod.rs        |   2 +-
 datafusion/core/tests/sql/aggregates.rs       |   2 +-
 datafusion/expr/src/aggregate_function.rs     |   2 +-
 .../physical-expr/src/aggregate/array_agg.rs  |  17 +-
 .../src/aggregate/array_agg_distinct.rs       |  11 +-
 .../src/aggregate/array_agg_ordered.rs        |  12 +-
 .../physical-expr/src/aggregate/build_in.rs   |   4 +-
 .../sqllogictest/test_files/aggregate.slt     | 155 +++++++++++++-----
 9 files changed, 161 insertions(+), 54 deletions(-)

diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs
index c8f21788cbbd..6c03e8698e80 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -1984,6 +1984,16 @@ impl ScalarValue {
         Self::new_list(values, data_type, true)
     }
 
+    /// Create a ListArray of nulls with the specified data type
+    ///
+    /// - new_null_list(i32, nullable, 1): `ListArray[NULL]`
+    pub fn new_null_list(data_type: DataType, nullable: bool, null_len: usize) -> Self {
+        let data_type = DataType::List(Field::new_list_field(data_type, nullable).into());
+        Self::List(Arc::new(ListArray::from(ArrayData::new_null(
+            &data_type, null_len,
+        ))))
+    }
+
     /// Converts `IntoIterator` where each element has type corresponding to
     /// `data_type`, to a [`ListArray`].
     ///
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index 2d1904d9e166..f1d57c44293b 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -1388,7 +1388,7 @@ async fn unnest_with_redundant_columns() -> Result<()> {
     let expected = vec![
         "Projection: shapes.shape_id [shape_id:UInt32]",
         "  Unnest: lists[shape_id2] structs[] [shape_id:UInt32, shape_id2:UInt32;N]",
-        "    Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} })]",
+        "    Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} });N]",
         "      TableScan: shapes projection=[shape_id] [shape_id:UInt32]",
     ];
 
diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs
index e503b74992c3..86032dc9bc96 100644
--- a/datafusion/core/tests/sql/aggregates.rs
+++ b/datafusion/core/tests/sql/aggregates.rs
@@ -37,7 +37,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> {
         Schema::new(vec![Field::new_list(
             "ARRAY_AGG(DISTINCT aggregate_test_100.c2)",
             Field::new("item", DataType::UInt32, false),
-            false
+            true
         ),])
     );
 
diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs
index 23e98714dfa4..3cae78eaed9b 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -118,7 +118,7 @@ impl AggregateFunction {
     pub fn nullable(&self) -> Result<bool> {
         match self {
             AggregateFunction::Max | AggregateFunction::Min => Ok(true),
-            AggregateFunction::ArrayAgg => Ok(false),
+            AggregateFunction::ArrayAgg => Ok(true),
         }
     }
 }
diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs
index 634a0a017903..38a973802933 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg.rs
@@ -71,7 +71,7 @@ impl AggregateExpr for ArrayAgg {
             &self.name,
             // This should be the same as return type of AggregateFunction::ArrayAgg
             Field::new("item", self.input_data_type.clone(), self.nullable),
-            false,
+            true,
         ))
     }
 
@@ -86,7 +86,7 @@ impl AggregateExpr for ArrayAgg {
         Ok(vec![Field::new_list(
             format_state_name(&self.name, "array_agg"),
             Field::new("item", self.input_data_type.clone(), self.nullable),
-            false,
+            true,
         )])
     }
 
@@ -137,8 +137,11 @@ impl Accumulator for ArrayAggAccumulator {
             return Ok(());
         }
         assert!(values.len() == 1, "array_agg can only take 1 param!");
+
         let val = Arc::clone(&values[0]);
-        self.values.push(val);
+        if val.len() > 0 {
+            self.values.push(val);
+        }
         Ok(())
     }
 
@@ -162,13 +165,15 @@ impl Accumulator for ArrayAggAccumulator {
     fn evaluate(&mut self) -> Result<ScalarValue> {
         // Transform Vec<ListArr> to ListArr
-
         let element_arrays: Vec<&dyn Array> =
             self.values.iter().map(|a| a.as_ref()).collect();
 
         if element_arrays.is_empty() {
-            let arr = ScalarValue::new_list(&[], &self.datatype, self.nullable);
-            return Ok(ScalarValue::List(arr));
+            return Ok(ScalarValue::new_null_list(
+                self.datatype.clone(),
+                self.nullable,
+                1,
+            ));
         }
 
         let concated_array = arrow::compute::concat(&element_arrays)?;
diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
index a59d85e84a20..368d11d7421a 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
@@ -75,7 +75,7 @@ impl AggregateExpr for DistinctArrayAgg {
             &self.name,
             // This should be the same as return type of AggregateFunction::ArrayAgg
             Field::new("item", self.input_data_type.clone(), self.nullable),
-            false,
+            true,
         ))
     }
 
@@ -90,7 +90,7 @@ impl AggregateExpr for DistinctArrayAgg {
         Ok(vec![Field::new_list(
             format_state_name(&self.name, "distinct_array_agg"),
             Field::new("item", self.input_data_type.clone(), self.nullable),
-            false,
+            true,
         )])
     }
 
@@ -165,6 +165,13 @@ impl Accumulator for DistinctArrayAggAccumulator {
     fn evaluate(&mut self) -> Result<ScalarValue> {
         let values: Vec<ScalarValue> = self.values.iter().cloned().collect();
+        if values.is_empty() {
+            return Ok(ScalarValue::new_null_list(
+                self.datatype.clone(),
+                self.nullable,
+                1,
+            ));
+        }
         let arr = ScalarValue::new_list(&values, &self.datatype, self.nullable);
         Ok(ScalarValue::List(arr))
     }
diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
index a64d97637c3b..d44811192f66 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
@@ -92,7 +92,7 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
             &self.name,
             // This should be the same as return type of AggregateFunction::ArrayAgg
             Field::new("item", self.input_data_type.clone(), self.nullable),
-            false,
+            true,
         ))
     }
 
@@ -111,7 +111,7 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
         let mut fields = vec![Field::new_list(
             format_state_name(&self.name, "array_agg"),
             Field::new("item", self.input_data_type.clone(), self.nullable),
-            false, // This should be the same as field()
+            true, // This should be the same as field()
         )];
         let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types);
         fields.push(Field::new_list(
@@ -309,6 +309,14 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
     }
 
     fn evaluate(&mut self) -> Result<ScalarValue> {
+        if self.values.is_empty() {
+            return Ok(ScalarValue::new_null_list(
+                self.datatypes[0].clone(),
+                self.nullable,
+                1,
+            ));
+        }
+
         let values = self.values.clone();
         let array = if self.reverse {
             ScalarValue::new_list_from_iter(
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index d4cd3d51d174..68c9b4859f1f 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -147,7 +147,7 @@ mod tests {
             Field::new_list(
                 "c1",
                 Field::new("item", data_type.clone(), true),
-                false,
+                true,
             ),
             result_agg_phy_exprs.field().unwrap()
         );
@@ -167,7 +167,7 @@ mod tests {
             Field::new_list(
                 "c1",
                 Field::new("item", data_type.clone(), true),
-                false,
+                true,
             ),
             result_agg_phy_exprs.field().unwrap()
         );
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index e891093c8156..7dd1ea82b327 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -1694,7 +1694,7 @@ SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT
 query ?
 SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 LIMIT 0) test
 ----
-[]
+NULL
 
 # csv_query_array_agg_one
 query ?
@@ -1753,31 +1753,12 @@ NULL 4 29 1.260869565217 123 -117 23
 NULL 5 -194 -13.857142857143 118 -101 14
 NULL NULL 781 7.81 125 -117 100
 
-# TODO: array_agg_distinct output is non-deterministic -- rewrite with array_sort(list_sort)
-# unnest is also not available, so manually unnesting via CROSS JOIN
-# additional count(1) forces array_agg_distinct instead of array_agg over aggregated by c2 data
-#
+# select with count to force the array_agg_distinct function, since a single distinct expression is converted to a group by by the optimizer
 # csv_query_array_agg_distinct
-query III
-WITH indices AS (
-  SELECT 1 AS idx UNION ALL
-  SELECT 2 AS idx UNION ALL
-  SELECT 3 AS idx UNION ALL
-  SELECT 4 AS idx UNION ALL
-  SELECT 5 AS idx
-)
-SELECT data.arr[indices.idx] as element, array_length(data.arr) as array_len, dummy
-FROM (
-  SELECT array_agg(distinct c2) as arr, count(1) as dummy FROM aggregate_test_100
-) data
-  CROSS JOIN indices
-ORDER BY 1
-----
-1 5 100
-2 5 100
-3 5 100
-4 5 100
-5 5 100
+query ?I
+SELECT array_sort(array_agg(distinct c2)), count(1) FROM aggregate_test_100
+----
+[1, 2, 3, 4, 5] 100
 
 # aggregate_time_min_and_max
 query TT
@@ -2732,6 +2713,16 @@ SELECT COUNT(DISTINCT c1) FROM test
 
 # TODO: aggregate_with_alias
 
+# test_approx_percentile_cont_decimal_support
+query TI
+SELECT c1, approx_percentile_cont(c2, cast(0.85 as decimal(10,2))) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1
+----
+a 4
+b 5
+c 4
+d 4
+e 4
+
 # array_agg_zero
 query ?
 SELECT ARRAY_AGG([])
@@ -2744,28 +2735,114 @@ SELECT ARRAY_AGG([1])
 ----
 [[1]]
 
-# test_approx_percentile_cont_decimal_support
-query TI
-SELECT c1, approx_percentile_cont(c2, cast(0.85 as decimal(10,2))) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1
-----
-a 4
-b 5
-c 4
-d 4
-e 4
+# test array_agg with no rows qualified
+statement ok
+create table t(a int, b float, c bigint) as values (1, 1.2, 2);
 
+# returns NULL, follows DuckDB's behaviour
+query ?
+select array_agg(a) from t where a > 2;
+----
+NULL
 
+query ?
+select array_agg(b) from t where b > 3.1;
+----
+NULL
 
-# array_agg_zero
 query ?
-SELECT ARRAY_AGG([]);
+select array_agg(c) from t where c > 3;
 ----
-[[]]
+NULL
 
-# array_agg_one
+query ?I
+select array_agg(c), count(1) from t where c > 3;
+----
+NULL 0
+
+# returns 0 rows if group by is applied, follows DuckDB's behaviour
 query ?
-SELECT ARRAY_AGG([1]);
+select array_agg(a) from t where a > 3 group by a;
 ----
-[[1]]
+
+query ?I
+select array_agg(a), count(1) from t where a > 3 group by a;
+----
+
+# returns NULL, follows DuckDB's behaviour
+query ?
+select array_agg(distinct a) from t where a > 3;
+----
+NULL
+
+query ?I
+select array_agg(distinct a), count(1) from t where a > 3;
+----
+NULL 0
+
+# returns 0 rows if group by is applied, follows DuckDB's behaviour
+query ?
+select array_agg(distinct a) from t where a > 3 group by a;
+----
+
+query ?I
+select array_agg(distinct a), count(1) from t where a > 3 group by a;
+----
+
+# test order sensitive array agg
+query ?
+select array_agg(a order by a) from t where a > 3;
+----
+NULL
+
+query ?
+select array_agg(a order by a) from t where a > 3 group by a;
+----
+
+query ?I
+select array_agg(a order by a), count(1) from t where a > 3 group by a;
+----
+
+statement ok
+drop table t;
+
+# test with no values
+statement ok
+create table t(a int, b float, c bigint);
+
+query ?
+select array_agg(a) from t;
+----
+NULL
+
+query ?
+select array_agg(b) from t;
+----
+NULL
+
+query ?
+select array_agg(c) from t;
+----
+NULL
+
+query ?I
+select array_agg(distinct a), count(1) from t;
+----
+NULL 0
+
+query ?I
+select array_agg(distinct b), count(1) from t;
+----
+NULL 0
+
+query ?I
+select array_agg(distinct b), count(1) from t;
+----
+NULL 0
+
+statement ok
+drop table t;
+
 # array_agg_i32
 statement ok
From 7a23ea9bce32dc8ae195caa8ca052673031c06c9 Mon Sep 17 00:00:00 2001
From: Jonah Gao
Date: Thu, 11 Jul 2024 09:38:15 +0800
Subject: [PATCH 02/14] Minor: return "not supported" for `COUNT DISTINCT` with
 multiple arguments (#11391)

* Minor: return "not supported" for COUNT DISTINCT with multiple arguments
* update condition
---
 datafusion/functions-aggregate/src/count.rs      | 6 +++++-
 datafusion/sqllogictest/test_files/aggregate.slt | 4 ++++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs
index bd0155df0271..0a667d35dce5 100644
--- a/datafusion/functions-aggregate/src/count.rs
+++ b/datafusion/functions-aggregate/src/count.rs
@@ -37,7 +37,7 @@ use arrow::{
     buffer::BooleanBuffer,
 };
 use datafusion_common::{
-    downcast_value, internal_err, DataFusionError, Result, ScalarValue,
+    downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue,
 };
 use datafusion_expr::function::StateFieldsArgs;
 use datafusion_expr::{
@@ -138,6 +138,10 @@ impl AggregateUDFImpl for Count {
             return Ok(Box::new(CountAccumulator::new()));
         }
 
+        if acc_args.input_exprs.len() > 1 {
+            return not_impl_err!("COUNT DISTINCT with multiple arguments");
+        }
+
         let data_type = acc_args.input_type;
         Ok(match data_type {
             // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 7dd1ea82b327..6fafc0a74110 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -2019,6 +2019,10 @@ SELECT count(c1, c2) FROM test
 ----
 3
 
+# count(distinct) with multiple arguments
+query error DataFusion error: This feature is not implemented: COUNT DISTINCT with multiple arguments
+SELECT count(distinct c1, c2) FROM test
+
 # count_null
 query III
 SELECT count(null), count(null, null), count(distinct null) FROM test
From 2413155a3ed808285e31421a8b6aac23b8abdb91 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Thu, 11 Jul 2024 08:56:47 -0600
Subject: [PATCH 03/14] feat: Add `fail_on_overflow` option to `BinaryExpr` (#11400)

* update tests
* update tests
* add rustdoc
* update PartialEq impl
* fix
* address feedback about improving api
---
 datafusion/core/src/physical_planner.rs       |   4 +-
 .../physical-expr/src/expressions/binary.rs   | 126 +++++++++++++++++-
 2 files changed, 121 insertions(+), 9 deletions(-)

diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index 6aad4d575532..d2bc334ec324 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -2312,7 +2312,7 @@ mod tests {
         // verify that the plan correctly casts u8 to i64
         // the cast from u8 to i64 for literal will be simplified, and get lit(int64(5))
        // the cast here is implicit so has CastOptions with safe=true
-        let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) } }";
+        let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) }, fail_on_overflow: false }";
         assert!(format!("{exec_plan:?}").contains(expected));
         Ok(())
     }
@@ -2551,7 +2551,7 @@ mod tests {
         let execution_plan = plan(&logical_plan).await?;
         // verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated.
 
-        let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") } }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") } } }";
+        let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") }, fail_on_overflow: false }, fail_on_overflow: false }";
 
         let actual = format!("{execution_plan:?}");
         assert!(actual.contains(expected), "{}", actual);
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index c153ead9639f..c34dcdfb7598 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -53,6 +53,8 @@ pub struct BinaryExpr {
     left: Arc<dyn PhysicalExpr>,
     op: Operator,
     right: Arc<dyn PhysicalExpr>,
+    /// Specifies whether an error is returned on overflow or not
+    fail_on_overflow: bool,
 }
 
 impl BinaryExpr {
@@ -62,7 +64,22 @@ impl BinaryExpr {
         op: Operator,
         right: Arc<dyn PhysicalExpr>,
     ) -> Self {
-        Self { left, op, right }
+        Self {
+            left,
+            op,
+            right,
+            fail_on_overflow: false,
+        }
+    }
+
+    /// Create new binary expression with explicit fail_on_overflow value
+    pub fn with_fail_on_overflow(self, fail_on_overflow: bool) -> Self {
+        Self {
+            left: self.left,
+            op: self.op,
+            right: self.right,
+            fail_on_overflow,
+        }
     }
 
     /// Get the left side of the binary expression
@@ -273,8 +290,11 @@ impl PhysicalExpr for BinaryExpr {
         }
 
         match self.op {
+            Operator::Plus if self.fail_on_overflow => return apply(&lhs, &rhs, add),
             Operator::Plus => return apply(&lhs, &rhs, add_wrapping),
+            Operator::Minus if self.fail_on_overflow => return apply(&lhs, &rhs, sub),
             Operator::Minus => return apply(&lhs, &rhs, sub_wrapping),
+            Operator::Multiply if self.fail_on_overflow => return apply(&lhs, &rhs, mul),
             Operator::Multiply => return apply(&lhs, &rhs, mul_wrapping),
             Operator::Divide => return apply(&lhs, &rhs, div),
             Operator::Modulo => return apply(&lhs, &rhs, rem),
@@ -327,11 +347,10 @@ impl PhysicalExpr for BinaryExpr {
         self: Arc<Self>,
         children: Vec<Arc<dyn PhysicalExpr>>,
     ) -> Result<Arc<dyn PhysicalExpr>> {
-        Ok(Arc::new(BinaryExpr::new(
-            Arc::clone(&children[0]),
-            self.op,
-            Arc::clone(&children[1]),
-        )))
+        Ok(Arc::new(
+            BinaryExpr::new(Arc::clone(&children[0]), self.op, Arc::clone(&children[1]))
+                .with_fail_on_overflow(self.fail_on_overflow),
+        ))
     }
 
     fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
@@ -496,7 +515,12 @@ impl PartialEq<dyn Any> for BinaryExpr {
     fn eq(&self, other: &dyn Any) -> bool {
         down_cast_any_ref(other)
             .downcast_ref::<Self>()
-            .map(|x| self.left.eq(&x.left) && self.op == x.op && self.right.eq(&x.right))
+            .map(|x| {
+                self.left.eq(&x.left)
+                    && self.op == x.op
+                    && self.right.eq(&x.right)
+                    && self.fail_on_overflow.eq(&x.fail_on_overflow)
+            })
             .unwrap_or(false)
     }
 }
@@ -661,6 +685,7 @@ mod tests {
     use datafusion_common::plan_datafusion_err;
     use datafusion_expr::type_coercion::binary::get_input_types;
+    use datafusion_physical_expr_common::expressions::column::Column;
 
     /// Performs a binary operation, applying any type coercion necessary
     fn binary_op(
@@ -4008,4 +4033,91 @@ mod tests {
         .unwrap();
         assert_eq!(&casted, &dictionary);
     }
+
+    #[test]
+    fn test_add_with_overflow() -> Result<()> {
+        // create test data
+        let l = Arc::new(Int32Array::from(vec![1, i32::MAX]));
+        let r = Arc::new(Int32Array::from(vec![2, 1]));
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("l", DataType::Int32, false),
+            Field::new("r", DataType::Int32, false),
+        ]));
+        let batch = RecordBatch::try_new(schema, vec![l, r])?;
+
+        // create expression
+        let expr = BinaryExpr::new(
+            Arc::new(Column::new("l", 0)),
+            Operator::Plus,
+            Arc::new(Column::new("r", 1)),
+        )
+        .with_fail_on_overflow(true);
+
+        // evaluate expression
+        let result = expr.evaluate(&batch);
+        assert!(result
+            .err()
+            .unwrap()
+            .to_string()
+            .contains("Overflow happened on: 2147483647 + 1"));
+        Ok(())
+    }
+
+    #[test]
+    fn test_subtract_with_overflow() -> Result<()> {
+        // create test data
+        let l = Arc::new(Int32Array::from(vec![1, i32::MIN]));
+        let r = Arc::new(Int32Array::from(vec![2, 1]));
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("l", DataType::Int32, false),
+            Field::new("r", DataType::Int32, false),
+        ]));
+        let batch = RecordBatch::try_new(schema, vec![l, r])?;
+
+        // create expression
+        let expr = BinaryExpr::new(
+            Arc::new(Column::new("l", 0)),
+            Operator::Minus,
+            Arc::new(Column::new("r", 1)),
+        )
+        .with_fail_on_overflow(true);
+
+        // evaluate expression
+        let result = expr.evaluate(&batch);
+        assert!(result
+            .err()
+            .unwrap()
+            .to_string()
+            .contains("Overflow happened on: -2147483648 - 1"));
+        Ok(())
+    }
+
+    #[test]
+    fn test_mul_with_overflow() -> Result<()> {
+        // create test data
+        let l = Arc::new(Int32Array::from(vec![1, i32::MAX]));
+        let r = Arc::new(Int32Array::from(vec![2, 2]));
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("l", DataType::Int32, false),
+            Field::new("r", DataType::Int32, false),
+        ]));
+        let batch = RecordBatch::try_new(schema, vec![l, r])?;
+
+        // create expression
+        let expr = BinaryExpr::new(
+            Arc::new(Column::new("l", 0)),
+            Operator::Multiply,
+            Arc::new(Column::new("r", 1)),
+        )
+        .with_fail_on_overflow(true);
+
+        // evaluate expression
+        let result = expr.evaluate(&batch);
+        assert!(result
+            .err()
+            .unwrap()
+            .to_string()
+            .contains("Overflow happened on: 2147483647 * 2"));
+        Ok(())
+    }
 }
From ed65c11065f74d72995619450d5325234aba0b5d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?=
Date: Thu, 11 Jul 2024 22:58:20 +0800
Subject: [PATCH 04/14] Enable clone_on_ref_ptr clippy lint on sql (#11380)

---
 datafusion/sql/examples/sql.rs     | 2 +-
 datafusion/sql/src/cte.rs          | 2 +-
 datafusion/sql/src/expr/mod.rs     | 2 +-
 datafusion/sql/src/lib.rs          | 2 ++
 datafusion/sql/src/statement.rs    | 4 ++--
 datafusion/sql/tests/common/mod.rs | 2 +-
 6 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs
index aee4cf5a38ed..1b92a7e116b1 100644
--- a/datafusion/sql/examples/sql.rs
+++ b/datafusion/sql/examples/sql.rs
@@ -119,7 +119,7 @@ fn create_table_source(fields: Vec<Field>) -> Arc<dyn TableSource> {
 impl ContextProvider for MyContextProvider {
     fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
         match self.tables.get(name.table()) {
-            Some(table) => Ok(table.clone()),
+            Some(table) => Ok(Arc::clone(table)),
             _ => plan_err!("Table not found: {}", name.table()),
         }
     }
diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs
index 0035dcda6ed7..3dfe00e3c5e0 100644
--- a/datafusion/sql/src/cte.rs
+++ b/datafusion/sql/src/cte.rs
@@ -144,7 +144,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         // as the input to the recursive term
         let work_table_plan = LogicalPlanBuilder::scan(
             cte_name.to_string(),
-            work_table_source.clone(),
+            Arc::clone(&work_table_source),
             None,
         )?
         .build()?;
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index 0546a101fcb2..859842e212be 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -981,7 +981,7 @@ mod tests {
     impl ContextProvider for TestContextProvider {
         fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
             match self.tables.get(name.table()) {
-                Some(table) => Ok(table.clone()),
+                Some(table) => Ok(Arc::clone(table)),
                 _ => plan_err!("Table not found: {}", name.table()),
             }
         }
diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs
index 1040cc61c702..eb5fec7a3c8b 100644
--- a/datafusion/sql/src/lib.rs
+++ b/datafusion/sql/src/lib.rs
@@ -14,6 +14,8 @@
 // KIND, either express or implied. See the License for the
 // specific language governing permissions and limitations
 // under the License.
+// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143
+#![deny(clippy::clone_on_ref_ptr)]
 
 //! This module provides:
 //!
diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs
index 6cdb2f959cd8..1acfac79acc0 100644
--- a/datafusion/sql/src/statement.rs
+++ b/datafusion/sql/src/statement.rs
@@ -870,12 +870,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     self.context_provider.get_table_source(table_ref.clone())?;
                 let plan =
                     LogicalPlanBuilder::scan(table_name, table_source, None)?.build()?;
-                let input_schema = plan.schema().clone();
+                let input_schema = Arc::clone(plan.schema());
                 (plan, input_schema, Some(table_ref))
             }
             CopyToSource::Query(query) => {
                 let plan = self.query_to_plan(query, &mut PlannerContext::new())?;
-                let input_schema = plan.schema().clone();
+                let input_schema = Arc::clone(plan.schema());
                 (plan, input_schema, None)
             }
         };
diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs
index f5caaefb3ea0..b8d8bd12d28b 100644
--- a/datafusion/sql/tests/common/mod.rs
+++ b/datafusion/sql/tests/common/mod.rs
@@ -258,6 +258,6 @@ impl TableSource for EmptyTable {
     }
 
     fn schema(&self) -> SchemaRef {
-        self.table_schema.clone()
+        Arc::clone(&self.table_schema)
     }
 }
From 0b2eb50c0f980562a6c009f541c4dbd5831b5fe1 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Thu, 11 Jul 2024 10:58:53 -0400
Subject: [PATCH 05/14] Move configuration information out of example usage
 page (#11300)

---
 datafusion/core/src/lib.rs                    |   6 +
 docs/source/index.rst                         |   8 +-
 docs/source/library-user-guide/index.md       |  21 ++-
 docs/source/user-guide/crate-configuration.md | 146 ++++++++++++++++++
 docs/source/user-guide/example-usage.md       | 129 ----------------
 5 files changed, 177 insertions(+), 133 deletions(-)
 create mode 100644 docs/source/user-guide/crate-configuration.md

diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs
index f5805bc06982..63dbe824c231 100644
--- a/datafusion/core/src/lib.rs
+++ b/datafusion/core/src/lib.rs
@@ -620,6 +620,12 @@ doc_comment::doctest!(
     user_guide_example_usage
 );
 
+#[cfg(doctest)]
+doc_comment::doctest!(
+    "../../../docs/source/user-guide/crate-configuration.md",
+    user_guide_crate_configuration
+);
+
 #[cfg(doctest)]
 doc_comment::doctest!(
     "../../../docs/source/user-guide/configs.md",
diff --git a/docs/source/index.rst b/docs/source/index.rst
index d491df04f7fe..8fbff208f561 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -41,13 +41,16 @@ DataFusion offers SQL and Dataframe APIs, excellent
 CSV, Parquet, JSON, and Avro, extensive customization, and a great
 community.
 
-To get started with examples, see the `example usage`_ section of the user guide and the `datafusion-examples`_ directory.
+To get started, see
 
-See the `developer’s guide`_ for contributing and `communication`_ for getting in touch with us.
+* The `example usage`_ section of the user guide and the `datafusion-examples`_ directory.
+* The `library user guide`_ for examples of using DataFusion's extension APIs
+* The `developer’s guide`_ for contributing and `communication`_ for getting in touch with us.
 
 .. _example usage: user-guide/example-usage.html
 .. _datafusion-examples: https://github.com/apache/datafusion/tree/main/datafusion-examples
 .. _developer’s guide: contributor-guide/index.html#developer-s-guide
+.. _library user guide: library-user-guide/index.html
 .. _communication: contributor-guide/communication.html
 .. _toc.asf-links:
@@ -80,6 +83,7 @@ See the `developer’s guide`_ for contributing and `communication`_ for getting in touch with us.
 
    user-guide/introduction
    user-guide/example-usage
+   user-guide/crate-configuration
   user-guide/cli/index
   user-guide/dataframe
   user-guide/expressions
diff --git a/docs/source/library-user-guide/index.md b/docs/source/library-user-guide/index.md
index 47257e0c926e..fd126a1120ed 100644
--- a/docs/source/library-user-guide/index.md
+++ b/docs/source/library-user-guide/index.md
@@ -19,8 +19,25 @@
 
 # Introduction
 
-The library user guide explains how to use the DataFusion library as a dependency in your Rust project. Please check out the user-guide for more details on how to use DataFusion's SQL and DataFrame APIs, or the contributor guide for details on how to contribute to DataFusion.
+The library user guide explains how to use the DataFusion library as a
+dependency in your Rust project and customize its behavior using its extension APIs.
 
-If you haven't reviewed the [architecture section in the docs][docs], it's a useful place to get the lay of the land before starting down a specific path.
+Please check out the [user guide] for getting started using
+DataFusion's SQL and DataFrame APIs, or the [contributor guide]
+for details on how to contribute to DataFusion.
 
+If you haven't reviewed the [architecture section in the docs][docs], it's a
+useful place to get the lay of the land before starting down a specific path.
+
+DataFusion is designed to be extensible at all points, including
+
+- [x] User Defined Functions (UDFs)
+- [x] User Defined Aggregate Functions (UDAFs)
+- [x] User Defined Table Source (`TableProvider`) for tables
+- [x] User Defined `Optimizer` passes (plan rewrites)
+- [x] User Defined `LogicalPlan` nodes
+- [x] User Defined `ExecutionPlan` nodes
+
+[user guide]: ../user-guide/example-usage.md
+[contributor guide]: ../contributor-guide/index.md
 [docs]: https://docs.rs/datafusion/latest/datafusion/#architecture
diff --git a/docs/source/user-guide/crate-configuration.md b/docs/source/user-guide/crate-configuration.md
new file mode 100644
index 000000000000..0587d06a3919
--- /dev/null
+++ b/docs/source/user-guide/crate-configuration.md
@@ -0,0 +1,146 @@
+
+
+# Crate Configuration
+
+This section contains information on how to configure DataFusion in your Rust
+project. See the [Configuration Settings] section for a list of options that
+control DataFusion's behavior.
+
+[configuration settings]: configs.md
+
+## Add latest non published DataFusion dependency
+
+DataFusion changes are published to `crates.io` according to the [release schedule](https://github.com/apache/datafusion/blob/main/dev/release/README.md#release-process)
+
+If you would like to test out DataFusion changes which are merged but not yet
+published, Cargo supports adding a dependency directly to a GitHub branch:
+
+```toml
+datafusion = { git = "https://github.com/apache/datafusion", branch = "main"}
+```
+
+It also works at the package level
+
+```toml
+datafusion-common = { git = "https://github.com/apache/datafusion", branch = "main", package = "datafusion-common"}
+```
+
+And with features
+
+```toml
+datafusion = { git = "https://github.com/apache/datafusion", branch = "main", default-features = false, features = ["unicode_expressions"] }
+```
+
+More on [Cargo dependencies](https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#specifying-dependencies)
+
+## Optimized Configuration
+
+For an optimized build several steps are required. First, use the below in your `Cargo.toml`. It is
+worth noting that using the settings in the `[profile.release]` section will significantly increase the build time.
+
+```toml
+[dependencies]
+datafusion = { version = "22.0" }
+tokio = { version = "^1.0", features = ["rt-multi-thread"] }
+snmalloc-rs = "0.3"
+
+[profile.release]
+lto = true
+codegen-units = 1
+```
+
+Then, in `main.rs`, update the memory allocator with the below after your imports:
+
+```rust ,ignore
+use datafusion::prelude::*;
+
+#[global_allocator]
+static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc;
+
+#[tokio::main]
+async fn main() -> datafusion::error::Result<()> {
+    Ok(())
+}
+```
+
+Based on the instruction set architecture you are building on you will want to configure the `target-cpu` as well, ideally
+with `native` or at least `avx2`.
+
+```shell
+RUSTFLAGS='-C target-cpu=native' cargo run --release
+```
+
+## Enable backtraces
+
+By default Datafusion returns errors as a plain message. There is an option to enable more verbose details about the error,
+like the error backtrace. To enable a backtrace you need to add the Datafusion `backtrace` feature to your `Cargo.toml` file:
+
+```toml
+datafusion = { version = "31.0.0", features = ["backtrace"]}
+```
+
+Set environment [variables](https://doc.rust-lang.org/std/backtrace/index.html#environment-variables)
+
+```bash
+RUST_BACKTRACE=1 ./target/debug/datafusion-cli
+DataFusion CLI v31.0.0
+> select row_numer() over (partition by a order by a) from (select 1 a);
+Error during planning: Invalid function 'row_numer'.
+Did you mean 'ROW_NUMBER'?
+
+backtrace:    0: std::backtrace_rs::backtrace::libunwind::trace
+             at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/../../backtrace/src/backtrace/libunwind.rs:93:5
+   1: std::backtrace_rs::backtrace::trace_unsynchronized
+             at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/../../backtrace/src/backtrace/mod.rs:66:5
+   2: std::backtrace::Backtrace::create
+             at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/backtrace.rs:332:13
+   3: std::backtrace::Backtrace::capture
+             at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/backtrace.rs:298:9
+   4: datafusion_common::error::DataFusionError::get_back_trace
+             at /datafusion/datafusion/common/src/error.rs:436:30
+   5: datafusion_sql::expr::function::>::sql_function_to_expr
+   ............
+```
+
+The backtraces are useful when debugging code. If there is a test in `datafusion/core/src/physical_planner.rs`
+
+```
+#[tokio::test]
+async fn test_get_backtrace_for_failed_code() -> Result<()> {
+    let ctx = SessionContext::new();
+
+    let sql = "
+        select row_numer() over (partition by a order by a) from (select 1 a);
+    ";
+
+    let _ = ctx.sql(sql).await?.collect().await?;
+
+    Ok(())
+}
+```
+
+To obtain a backtrace:
+
+```bash
+cargo build --features=backtrace
+RUST_BACKTRACE=1 cargo test --features=backtrace --package datafusion --lib -- physical_planner::tests::test_get_backtrace_for_failed_code --exact --nocapture
+```
+
+Note: The backtrace is wrapped in system calls, so some steps at the top of the backtrace can be ignored
diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md
index 7dbd4045e75b..813dbb1bc02a 100644
--- a/docs/source/user-guide/example-usage.md
+++ b/docs/source/user-guide/example-usage.md
@@ -33,29 +33,6 @@ datafusion = "latest_version"
 tokio = { version = "1.0", features = ["rt-multi-thread"] }
 ```
 
-## Add latest non published DataFusion dependency
-
-DataFusion changes are published to `crates.io` according to [release schedule](https://github.com/apache/datafusion/blob/main/dev/release/README.md#release-process)
-In case if it is required to test out DataFusion changes which are merged but yet to be published, Cargo supports adding dependency directly to GitHub branch
-
-```toml
-datafusion = { git = "https://github.com/apache/datafusion", branch = "main"}
-```
-
-Also it works on the package level
-
-```toml
-datafusion-common = { git = "https://github.com/apache/datafusion", branch = "main", package = "datafusion-common"}
-```
-
-And with features
-
-```toml
-datafusion = { git = "https://github.com/apache/datafusion", branch = "main", default-features = false, features = ["unicode_expressions"] }
-```
-
-More on [Cargo dependencies](https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#specifying-dependencies)
-
 ## Run a SQL query against data stored in a CSV
 
 ```rust
@@ -201,109 +178,3 @@ async fn main() -> datafusion::error::Result<()> {
 | 1 | 2      |
 +---+--------+
 ```
-
-## Extensibility
-
-DataFusion is designed to be extensible at all points. To that end, you can provide your own custom:
-
-- [x] User Defined Functions (UDFs)
-- [x] User Defined Aggregate Functions (UDAFs)
-- [x] User Defined Table Source (`TableProvider`) for tables
-- [x] User Defined `Optimizer` passes (plan rewrites)
-- [x] User Defined `LogicalPlan` nodes
-- [x] User Defined `ExecutionPlan` nodes
-
-## Optimized Configuration
-
-For an optimized build several steps are required. First, use the below in your `Cargo.toml`. It is
-worth noting that using the settings in the `[profile.release]` section will significantly increase the build time.
-
-```toml
-[dependencies]
-datafusion = { version = "22.0" }
-tokio = { version = "^1.0", features = ["rt-multi-thread"] }
-snmalloc-rs = "0.3"
-
-[profile.release]
-lto = true
-codegen-units = 1
-```
-
-Then, in `main.rs.` update the memory allocator with the below after your imports:
-
-```rust ,ignore
-use datafusion::prelude::*;
-
-#[global_allocator]
-static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc;
-
-#[tokio::main]
-async fn main() -> datafusion::error::Result<()> {
-    Ok(())
-}
-```
-
-Based on the instruction set architecture you are building on you will want to configure the `target-cpu` as well, ideally
-with `native` or at least `avx2`.
-
-```shell
-RUSTFLAGS='-C target-cpu=native' cargo run --release
-```
-
-## Enable backtraces
-
-By default Datafusion returns errors as a plain message. There is option to enable more verbose details about the error,
-like error backtrace. To enable a backtrace you need to add Datafusion `backtrace` feature to your `Cargo.toml` file:
-
-```toml
-datafusion = { version = "31.0.0", features = ["backtrace"]}
-```
-
-Set environment [variables](https://doc.rust-lang.org/std/backtrace/index.html#environment-variables)
-
-```bash
-RUST_BACKTRACE=1 ./target/debug/datafusion-cli
-DataFusion CLI v31.0.0
-> select row_numer() over (partition by a order by a) from (select 1 a);
-Error during planning: Invalid function 'row_numer'.
-Did you mean 'ROW_NUMBER'?
-
-backtrace:    0: std::backtrace_rs::backtrace::libunwind::trace
-             at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/../../backtrace/src/backtrace/libunwind.rs:93:5
-   1: std::backtrace_rs::backtrace::trace_unsynchronized
-             at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/../../backtrace/src/backtrace/mod.rs:66:5
-   2: std::backtrace::Backtrace::create
-             at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/backtrace.rs:332:13
-   3: std::backtrace::Backtrace::capture
-             at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/backtrace.rs:298:9
-   4: datafusion_common::error::DataFusionError::get_back_trace
-             at /datafusion/datafusion/common/src/error.rs:436:30
-   5: datafusion_sql::expr::function::>::sql_function_to_expr
-   ............
-```
-
-The backtraces are useful when debugging code. If there is a test in `datafusion/core/src/physical_planner.rs`
-
-```
-#[tokio::test]
-async fn test_get_backtrace_for_failed_code() -> Result<()> {
-    let ctx = SessionContext::new();
-
-    let sql = "
-        select row_numer() over (partition by a order by a) from (select 1 a);
-    ";
-
-    let _ = ctx.sql(sql).await?.collect().await?;
-
-    Ok(())
-}
-```
-
-To obtain a backtrace:
-
-```bash
-cargo build --features=backtrace
-RUST_BACKTRACE=1 cargo test --features=backtrace --package datafusion --lib -- physical_planner::tests::test_get_backtrace_for_failed_code --exact --nocapture
-```
-
-Note: The backtrace wrapped into systems calls, so some steps on top of the backtrace can be ignored
From faa1e98fc4bec6040c8de07d6c19973e572ad62d Mon Sep 17 00:00:00 2001
From: Arttu
Date: Thu, 11 Jul 2024 18:07:53 +0200
Subject: [PATCH 06/14] reuse a single function to create the tpch test
 contexts (#11396)

---
 .../tests/cases/consumer_integration.rs      | 207 ++++++------------
 1 file changed, 62 insertions(+), 145 deletions(-)

diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs
index 6133c239873b..10c1319b903b 100644
--- a/datafusion/substrait/tests/cases/consumer_integration.rs
+++ b/datafusion/substrait/tests/cases/consumer_integration.rs
@@ -32,151 +32,22 @@ mod tests {
     use std::io::BufReader;
     use substrait::proto::Plan;
 
-    async fn register_csv(
-        ctx: &SessionContext,
-        table_name: &str,
-        file_path: &str,
-    ) -> Result<()> {
-        ctx.register_csv(table_name, file_path, CsvReadOptions::default())
-            .await
-    }
-
-    async fn create_context_tpch1() -> Result<SessionContext> {
-        let ctx = SessionContext::new();
-        register_csv(
-            &ctx,
-            "FILENAME_PLACEHOLDER_0",
-            "tests/testdata/tpch/lineitem.csv",
-        )
-        .await?;
-        Ok(ctx)
-    }
-
-    async fn create_context_tpch2() -> Result<SessionContext> {
-        let ctx = SessionContext::new();
-
-        let registrations = vec![
-            ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/part.csv"),
-            ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/supplier.csv"),
-            ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/partsupp.csv"),
-            ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"),
-            ("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/region.csv"),
-            ("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/partsupp.csv"),
-            ("FILENAME_PLACEHOLDER_6", "tests/testdata/tpch/supplier.csv"),
-            ("FILENAME_PLACEHOLDER_7", "tests/testdata/tpch/nation.csv"),
-            ("FILENAME_PLACEHOLDER_8", "tests/testdata/tpch/region.csv"),
-        ];
-
-        for (table_name, file_path) in registrations {
-            register_csv(&ctx, table_name, file_path).await?;
-        }
-
-        Ok(ctx)
-    }
-
-    async fn create_context_tpch3() -> Result<SessionContext> {
-        let ctx = SessionContext::new();
-
-        let registrations = vec![
-            ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"),
-            ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"),
-            ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"),
-        ];
-
-        for (table_name, file_path) in registrations {
-            register_csv(&ctx, table_name, file_path).await?;
-        }
-
-        Ok(ctx)
-    }
-
-    async fn create_context_tpch4() -> Result<SessionContext> {
-        let ctx = SessionContext::new();
-
-        let registrations = vec![
-            ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/orders.csv"),
-            ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/lineitem.csv"),
-        ];
-
-        for (table_name, file_path) in registrations {
-            register_csv(&ctx, table_name, file_path).await?;
-        }
-
-        Ok(ctx)
-    }
-
-    async fn create_context_tpch5() -> Result<SessionContext> {
-        let ctx = SessionContext::new();
-
-        let registrations = vec![
-            ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"),
-            ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"),
-            ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"),
-            ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/supplier.csv"),
-            ("NATION", "tests/testdata/tpch/nation.csv"),
-            ("REGION", "tests/testdata/tpch/region.csv"),
-        ];
-
-        for (table_name, file_path) in registrations {
-            register_csv(&ctx, table_name, file_path).await?;
-        }
-
-        Ok(ctx)
-    }
-
-    async fn create_context_tpch6() -> Result<SessionContext> {
-        let ctx = SessionContext::new();
-
-        let registrations =
-            vec![("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/lineitem.csv")];
-
-        for (table_name, file_path) in registrations {
-            register_csv(&ctx, table_name, file_path).await?;
-        }
-
-        Ok(ctx)
-    }
-    // missing context for query 7,8,9
-
-    async fn create_context_tpch10() -> Result<SessionContext> {
+    async fn create_context(files: Vec<(&str, &str)>) -> Result<SessionContext> {
         let ctx = SessionContext::new();
-
-        let registrations = vec![
-            ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"),
-            ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"),
-            ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"),
-            ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"),
-        ];
-
-        for (table_name, file_path) in registrations {
-            register_csv(&ctx, table_name, file_path).await?;
+        for (table_name, file_path) in files {
+            ctx.register_csv(table_name, file_path, CsvReadOptions::default())
+                .await?;
         }
-
-        Ok(ctx)
-    }
-
-    async fn create_context_tpch11() -> Result<SessionContext> {
-        let ctx = SessionContext::new();
-
-        let registrations = vec![
-            ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/partsupp.csv"),
-            ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/supplier.csv"),
-            ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/nation.csv"),
-            ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/partsupp.csv"),
-            ("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/supplier.csv"),
-            ("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/nation.csv"),
-        ];
-
-        for (table_name, file_path) in registrations {
-            register_csv(&ctx, table_name, file_path).await?;
-        }
-
         Ok(ctx)
     }
 
     #[tokio::test]
     async fn tpch_test_1() -> Result<()> {
-        let ctx = create_context_tpch1().await?;
+        let ctx = create_context(vec![(
+            "FILENAME_PLACEHOLDER_0",
+            "tests/testdata/tpch/lineitem.csv",
+        )])
+        .await?;
         let path = "tests/testdata/tpch_substrait_plans/query_1.json";
         let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
             File::open(path).expect("file not found"),
@@ -200,7 +71,18 @@ mod tests {
 
     #[tokio::test]
     async fn tpch_test_2() -> Result<()> {
-        let ctx = create_context_tpch2().await?;
+        let ctx = create_context(vec![
+            ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/part.csv"),
+            ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/supplier.csv"),
+            ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/partsupp.csv"),
+            ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"),
+            ("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/region.csv"),
+            ("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/partsupp.csv"),
+            ("FILENAME_PLACEHOLDER_6", "tests/testdata/tpch/supplier.csv"),
+            ("FILENAME_PLACEHOLDER_7", "tests/testdata/tpch/nation.csv"),
+            ("FILENAME_PLACEHOLDER_8", "tests/testdata/tpch/region.csv"),
+        ])
+        .await?;
         let path = "tests/testdata/tpch_substrait_plans/query_2.json";
         let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
             File::open(path).expect("file not found"),
@@ -242,7 +124,12 @@ mod tests {
 
     #[tokio::test]
     async fn tpch_test_3() -> Result<()> {
-        let ctx = create_context_tpch3().await?;
+        let ctx = create_context(vec![
+            ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"),
+            ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"),
+            ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"),
+        ])
+        .await?;
         let path = "tests/testdata/tpch_substrait_plans/query_3.json";
         let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
             File::open(path).expect("file not found"),
@@ -267,7 +154,11 @@ mod tests {
 
     #[tokio::test]
     async fn tpch_test_4() -> Result<()> {
-        let ctx = create_context_tpch4().await?;
+        let ctx = create_context(vec![
+            ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/orders.csv"),
+            ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/lineitem.csv"),
+        ])
+        .await?;
         let path = "tests/testdata/tpch_substrait_plans/query_4.json";
         let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
             File::open(path).expect("file not found"),
@@ -289,7 +180,15 @@ mod tests {
 
     #[tokio::test]
     async fn tpch_test_5() -> Result<()> {
-        let ctx = create_context_tpch5().await?;
+        let ctx = create_context(vec![
+            ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"),
+            ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"),
+            ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"),
+            ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/supplier.csv"),
+            ("NATION", "tests/testdata/tpch/nation.csv"),
+            ("REGION", "tests/testdata/tpch/region.csv"),
+        ])
+        .await?;
         let path = "tests/testdata/tpch_substrait_plans/query_5.json";
         let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
             File::open(path).expect("file not found"),
@@ -319,7 +218,11 @@ mod tests {
 
     #[tokio::test]
     async fn tpch_test_6() -> Result<()> {
-        let ctx = create_context_tpch6().await?;
+        let ctx = create_context(vec![(
+            "FILENAME_PLACEHOLDER_0",
+            "tests/testdata/tpch/lineitem.csv",
+        )])
+        .await?;
         let path = "tests/testdata/tpch_substrait_plans/query_6.json";
         let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
             File::open(path).expect("file not found"),
@@ -338,7 +241,13 @@ mod tests {
     // TODO: missing plan 7, 8, 9
     #[tokio::test]
     async fn tpch_test_10() -> Result<()> {
-        let ctx = create_context_tpch10().await?;
+        let ctx = create_context(vec![
+            ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"),
+            ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"),
+            ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"),
+            ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"),
+        ])
+        .await?;
         let path = "tests/testdata/tpch_substrait_plans/query_10.json";
         let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
             File::open(path).expect("file not found"),
@@ -365,7 +274,15 @@ mod tests {
 
     #[tokio::test]
     async fn tpch_test_11() -> Result<()> {
-        let ctx = create_context_tpch11().await?;
+        let ctx = create_context(vec![
+            ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/partsupp.csv"),
+            ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/supplier.csv"),
+            ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/nation.csv"),
+            ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/partsupp.csv"),
+            ("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/supplier.csv"),
+            ("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/nation.csv"),
+        ])
+        .await?;
         let path = "tests/testdata/tpch_substrait_plans/query_11.json";
         let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
             File::open(path).expect("file not found"),
From 6692382f22f04542534bba0183cf0682fd932da1 Mon Sep 17 00:00:00 2001
From: Marco Neumann
Date: Thu, 11 Jul 2024 18:17:03 +0200
Subject: [PATCH 07/14] refactor: change error type for "no statement" (#11411)

Amends #11394 (sorry, I should have reviewed that). While reporting
"not implemented" for "multiple statements" seems reasonable, I think
the user should get a plan error (which roughly translates to "invalid
argument") if they don't provide any statement. I don't see any
reasonable way to support "no statement" ever, hence "not implemented"
seems like a wrong promise.
---
 datafusion/core/src/execution/session_state.rs | 4 +---
 datafusion/core/tests/sql/sql_api.rs           | 2 +-
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs
index 60745076c242..dbfba9ea9352 100644
--- a/datafusion/core/src/execution/session_state.rs
+++ b/datafusion/core/src/execution/session_state.rs
@@ -554,9 +554,7 @@ impl SessionState {
             );
         }
         let statement = statements.pop_front().ok_or_else(|| {
-            DataFusionError::NotImplemented(
-                "No SQL statements were provided in the query string".to_string(),
-            )
+            plan_datafusion_err!("No SQL statements were provided in the query string")
         })?;
         Ok(statement)
     }
diff --git a/datafusion/core/tests/sql/sql_api.rs b/datafusion/core/tests/sql/sql_api.rs
index e7c40d2c8aa8..48f4a66b65dc 100644
--- a/datafusion/core/tests/sql/sql_api.rs
+++ b/datafusion/core/tests/sql/sql_api.rs
@@ -124,7 +124,7 @@ async fn empty_statement_returns_error() {
     let plan_res = state.create_logical_plan("").await;
     assert_eq!(
         plan_res.unwrap_err().strip_backtrace(),
-        "This feature is not implemented: No SQL statements were provided in the query string"
+        "Error during planning: No SQL statements were provided in the query string"
     );
 }
From f284e3bb73e089abc0c06b3314014522411bf1da Mon Sep 17 00:00:00 2001
From: Chunchun Ye <14298407+appletreeisyellow@users.noreply.github.com>
Date: Thu, 11 Jul 2024 11:17:09 -0500
Subject: [PATCH 08/14] feat: add UDF to_local_time() (#11347)

* feat: add UDF `to_local_time()`
* chore: support column value in array
* chore: lint
* chore: fix conversion for us, ms, and s
* chore: add more tests for daylight savings time
* chore: add function description
* refactor: update tests and add examples in description
* chore: add description and example
* chore: doc
chore: doc
chore: doc
chore: doc
chore: doc
* chore: stop copying
* chore: fix typo
* chore: mention that the offset varies based on daylight savings time
* refactor: parse timezone once and update examples in description
* refactor: replace map..concat with flat_map
* chore: add hard code timestamp value in test
chore: doc
chore: doc
* chore: handle errors and remove panics
* chore: move some test to slt
* chore: clone time_value
* chore: typo

---------

Co-authored-by: Andrew Lamb
---
 datafusion/functions/src/datetime/mod.rs      |  11 +-
 .../functions/src/datetime/to_local_time.rs   | 564 ++++++++++++++++++
 .../sqllogictest/test_files/timestamps.slt    | 177 ++++++
 3 files changed, 751 insertions(+), 1 deletion(-)
 create mode 100644 datafusion/functions/src/datetime/to_local_time.rs

diff --git a/datafusion/functions/src/datetime/mod.rs b/datafusion/functions/src/datetime/mod.rs
index 9c2f80856bf8..a7e9827d6ca6 100644
--- a/datafusion/functions/src/datetime/mod.rs
+++ b/datafusion/functions/src/datetime/mod.rs
@@ -32,6 +32,7 @@ pub mod make_date;
 pub mod now;
 pub mod to_char;
 pub mod to_date;
+pub mod to_local_time;
 pub mod to_timestamp;
 pub mod to_unixtime;
 
@@ -50,6 +51,7 @@ make_udf_function!(
 make_udf_function!(now::NowFunc, NOW, now);
 make_udf_function!(to_char::ToCharFunc, TO_CHAR, to_char);
 make_udf_function!(to_date::ToDateFunc, TO_DATE, to_date);
+make_udf_function!(to_local_time::ToLocalTimeFunc, TO_LOCAL_TIME, to_local_time);
 make_udf_function!(to_unixtime::ToUnixtimeFunc, TO_UNIXTIME, to_unixtime);
 make_udf_function!(to_timestamp::ToTimestampFunc, TO_TIMESTAMP, to_timestamp);
 make_udf_function!(
@@ -108,7 +110,13 @@ pub mod expr_fn {
     ),(
         now,
         "returns the current timestamp in nanoseconds, using the same value for all instances of now() in same statement",
-    ),(
+    ),
+    (
+        to_local_time,
+        "converts a timezone-aware timestamp to local time (with no offset or timezone information), i.e. strips off the timezone from the timestamp",
+        args,
+    ),
+    (
         to_unixtime,
         "converts a string and optional formats to a Unixtime",
         args,
@@ -277,6 +285,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
         now(),
         to_char(),
         to_date(),
+        to_local_time(),
         to_unixtime(),
         to_timestamp(),
         to_timestamp_seconds(),
diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs
new file mode 100644
index 000000000000..c84d1015bd7e
--- /dev/null
+++ b/datafusion/functions/src/datetime/to_local_time.rs
@@ -0,0 +1,564 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::ops::Add;
+use std::sync::Arc;
+
+use arrow::array::timezone::Tz;
+use arrow::array::{Array, ArrayRef, PrimitiveBuilder};
+use arrow::datatypes::DataType::Timestamp;
+use arrow::datatypes::{
+    ArrowTimestampType, DataType, TimestampMicrosecondType, TimestampMillisecondType,
+    TimestampNanosecondType, TimestampSecondType,
+};
+use arrow::datatypes::{
+    TimeUnit,
+    TimeUnit::{Microsecond, Millisecond, Nanosecond, Second},
+};
+
+use chrono::{DateTime, MappedLocalTime, Offset, TimeDelta, TimeZone, Utc};
+use datafusion_common::cast::as_primitive_array;
+use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
+use datafusion_expr::TypeSignature::Exact;
+use datafusion_expr::{
+    ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD,
+};
+
+/// A UDF function that converts a timezone-aware timestamp to local time (with no offset or
+/// timezone information). In other words, this function strips off the timezone from the timestamp,
+/// while keeping the display value of the timestamp the same.
+#[derive(Debug)]
+pub struct ToLocalTimeFunc {
+    signature: Signature,
+}
+
+impl Default for ToLocalTimeFunc {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ToLocalTimeFunc {
+    pub fn new() -> Self {
+        let base_sig = |array_type: TimeUnit| {
+            [
+                Exact(vec![Timestamp(array_type, None)]),
+                Exact(vec![Timestamp(array_type, Some(TIMEZONE_WILDCARD.into()))]),
+            ]
+        };
+
+        let full_sig = [Nanosecond, Microsecond, Millisecond, Second]
+            .into_iter()
+            .flat_map(base_sig)
+            .collect::<Vec<_>>();
+
+        Self {
+            signature: Signature::one_of(full_sig, Volatility::Immutable),
+        }
+    }
+
+    fn to_local_time(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        if args.len() != 1 {
+            return exec_err!(
+                "to_local_time function requires 1 argument, got {}",
+                args.len()
+            );
+        }
+
+        let time_value = &args[0];
+        let arg_type = time_value.data_type();
+        match arg_type {
+            DataType::Timestamp(_, None) => {
+                // if no timezone specified, just return the input
+                Ok(time_value.clone())
+            }
+            // If it has a timezone, adjust the underlying time value. The current time value
+            // is stored as i64 in UTC, even though the timezone may not be in UTC. Therefore,
+            // we need to adjust the time value to the local time. See [`adjust_to_local_time`]
+            // for more details.
+            //
+            // Then remove the timezone in return type, i.e. return None
+            DataType::Timestamp(_, Some(timezone)) => {
+                let tz: Tz = timezone.parse()?;
+
+                match time_value {
+                    ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(
+                        Some(ts),
+                        Some(_),
+                    )) => {
+                        let adjusted_ts =
+                            adjust_to_local_time::<TimestampNanosecondType>(*ts, tz)?;
+                        Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(
+                            Some(adjusted_ts),
+                            None,
+                        )))
+                    }
+                    ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(
+                        Some(ts),
+                        Some(_),
+                    )) => {
+                        let adjusted_ts =
+                            adjust_to_local_time::<TimestampMicrosecondType>(*ts, tz)?;
+                        Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(
+                            Some(adjusted_ts),
+                            None,
+                        )))
+                    }
+                    ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
+                        Some(ts),
+                        Some(_),
+                    )) => {
+                        let adjusted_ts =
+                            adjust_to_local_time::<TimestampMillisecondType>(*ts, tz)?;
+                        Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
+                            Some(adjusted_ts),
+                            None,
+                        )))
+                    }
+                    ColumnarValue::Scalar(ScalarValue::TimestampSecond(
+                        Some(ts),
+                        Some(_),
+                    )) => {
+                        let adjusted_ts =
+                            adjust_to_local_time::<TimestampSecondType>(*ts, tz)?;
+                        Ok(ColumnarValue::Scalar(ScalarValue::TimestampSecond(
+                            Some(adjusted_ts),
+                            None,
+                        )))
+                    }
+                    ColumnarValue::Array(array) => {
+                        fn transform_array<T: ArrowTimestampType>(
+                            array: &ArrayRef,
+                            tz: Tz,
+                        ) -> Result<ColumnarValue> {
+                            let mut builder = PrimitiveBuilder::<T>::new();
+
+                            let primitive_array = as_primitive_array::<T>(array)?;
+                            for ts_opt in primitive_array.iter() {
+                                match ts_opt {
+                                    None => builder.append_null(),
+                                    Some(ts) => {
+                                        let adjusted_ts: i64 =
+                                            adjust_to_local_time::<T>(ts, tz)?;
+                                        builder.append_value(adjusted_ts)
+                                    }
+                                }
+                            }
+
+                            Ok(ColumnarValue::Array(Arc::new(builder.finish())))
+                        }
+
+                        match array.data_type() {
+                            Timestamp(_, None) => {
+                                // if no timezone specified, just return the input
+                                Ok(time_value.clone())
+                            }
+                            Timestamp(Nanosecond, Some(_)) => {
+                                transform_array::<TimestampNanosecondType>(array, tz)
+                            }
+                            Timestamp(Microsecond, Some(_)) => {
+                                transform_array::<TimestampMicrosecondType>(array, tz)
+                            }
+                            Timestamp(Millisecond, Some(_)) => {
+                                transform_array::<TimestampMillisecondType>(array, tz)
+                            }
+                            Timestamp(Second, Some(_)) => {
+                                transform_array::<TimestampSecondType>(array, tz)
+                            }
+                            _ => {
+                                exec_err!("to_local_time function requires timestamp argument in array, got {:?}", array.data_type())
+                            }
+                        }
+                    }
+                    _ => {
+                        exec_err!(
+                            "to_local_time function requires timestamp argument, got {:?}",
+                            time_value.data_type()
+                        )
+                    }
+                }
+            }
+            _ => {
+                exec_err!(
+                    "to_local_time function requires timestamp argument, got {:?}",
+                    arg_type
+                )
+            }
+        }
+    }
+}
+
+/// This function converts a timestamp with a timezone to a timestamp without a timezone.
+/// The display value of the adjusted timestamp remains the same, but the underlying timestamp
+/// representation is adjusted according to the relative timezone offset to UTC.
+///
+/// This function uses chrono to handle daylight saving time changes.
+///
+/// For example,
+///
+/// ```text
+/// '2019-03-31T01:00:00Z'::timestamp at time zone 'Europe/Brussels'
+/// ```
+///
+/// is displayed as follows in datafusion-cli:
+///
+/// ```text
+/// 2019-03-31T01:00:00+01:00
+/// ```
+///
+/// and is represented in DataFusion as:
+///
+/// ```text
+/// TimestampNanosecond(Some(1_553_990_400_000_000_000), Some("Europe/Brussels"))
+/// ```
+///
+/// To strip off the timezone while keeping the display value the same, we need to
+/// adjust the underlying timestamp with the timezone offset value using `adjust_to_local_time()`
+///
+/// ```text
+/// adjust_to_local_time(1_553_990_400_000_000_000, "Europe/Brussels") --> 1_553_994_000_000_000_000
+/// ```
+///
+/// The difference between `1_553_990_400_000_000_000` and `1_553_994_000_000_000_000` is
+/// `3600_000_000_000` ns, which corresponds to 1 hour. This matches with the timezone
+/// offset for "Europe/Brussels" for this date.
+///
+/// Note that the offset varies with daylight savings time (DST), which makes this tricky! For
+/// example, timezone "Europe/Brussels" has a 2-hour offset during DST and a 1-hour offset
+/// when DST ends.
+///
+/// Consequently, DataFusion can represent the timestamp in local time (with no offset or
+/// timezone information) as
+///
+/// ```text
+/// TimestampNanosecond(Some(1_553_994_000_000_000_000), None)
+/// ```
+///
+/// which is displayed as follows in datafusion-cli:
+///
+/// ```text
+/// 2019-03-31T01:00:00
+/// ```
+///
+/// See `test_adjust_to_local_time()` for an example
+fn adjust_to_local_time<T: ArrowTimestampType>(ts: i64, tz: Tz) -> Result<i64> {
+    fn convert_timestamp<F>(ts: i64, converter: F) -> Result<DateTime<Utc>>
+    where
+        F: Fn(i64) -> MappedLocalTime<DateTime<Utc>>,
+    {
+        match converter(ts) {
+            MappedLocalTime::Ambiguous(earliest, latest) => exec_err!(
+                "Ambiguous timestamp. Do you mean {:?} or {:?}",
+                earliest,
+                latest
+            ),
+            MappedLocalTime::None => exec_err!(
+                "The local time does not exist because there is a gap in the local time."
+            ),
+            MappedLocalTime::Single(date_time) => Ok(date_time),
+        }
+    }
+
+    let date_time = match T::UNIT {
+        Nanosecond => Utc.timestamp_nanos(ts),
+        Microsecond => convert_timestamp(ts, |ts| Utc.timestamp_micros(ts))?,
+        Millisecond => convert_timestamp(ts, |ts| Utc.timestamp_millis_opt(ts))?,
+        Second => convert_timestamp(ts, |ts| Utc.timestamp_opt(ts, 0))?,
+    };
+
+    let offset_seconds: i64 = tz
+        .offset_from_utc_datetime(&date_time.naive_utc())
+        .fix()
+        .local_minus_utc() as i64;
+
+    let adjusted_date_time = date_time.add(
+        // This should not fail under normal circumstances as the
+        // maximum possible offset is 26 hours (93,600 seconds)
+        TimeDelta::try_seconds(offset_seconds)
+            .ok_or(DataFusionError::Internal("Offset seconds should be less than i64::MAX / 1_000 or greater than -i64::MAX / 1_000".to_string()))?,
+    );
+
+    // convert the naive datetime back to i64
+    match T::UNIT {
+        Nanosecond => adjusted_date_time.timestamp_nanos_opt().ok_or(
+            DataFusionError::Internal(
+                "Failed to convert DateTime to timestamp in nanosecond. This error may occur if the date is out of range.
The supported date ranges are between 1677-09-21T00:12:43.145224192 and 2262-04-11T23:47:16.854775807".to_string(),
+            ),
+        ),
+        Microsecond => Ok(adjusted_date_time.timestamp_micros()),
+        Millisecond => Ok(adjusted_date_time.timestamp_millis()),
+        Second => Ok(adjusted_date_time.timestamp()),
+    }
+}
+
+impl ScalarUDFImpl for ToLocalTimeFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "to_local_time"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        if arg_types.len() != 1 {
+            return exec_err!(
+                "to_local_time function requires 1 argument, got {:?}",
+                arg_types.len()
+            );
+        }
+
+        match &arg_types[0] {
+            Timestamp(Nanosecond, _) => Ok(Timestamp(Nanosecond, None)),
+            Timestamp(Microsecond, _) => Ok(Timestamp(Microsecond, None)),
+            Timestamp(Millisecond, _) => Ok(Timestamp(Millisecond, None)),
+            Timestamp(Second, _) => Ok(Timestamp(Second, None)),
+            _ => exec_err!(
+                "The to_local_time function can only accept timestamp as the arg, got {:?}", arg_types[0]
+            ),
+        }
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        if args.len() != 1 {
+            return exec_err!(
+                "to_local_time function requires 1 argument, got {:?}",
+                args.len()
+            );
+        }
+
+        self.to_local_time(args)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use arrow::array::{types::TimestampNanosecondType, TimestampNanosecondArray};
+    use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos;
+    use arrow::datatypes::{DataType, TimeUnit};
+    use chrono::NaiveDateTime;
+    use datafusion_common::ScalarValue;
+    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
+
+    use super::{adjust_to_local_time, ToLocalTimeFunc};
+
+    #[test]
+    fn test_adjust_to_local_time() {
+        let timestamp_str = "2020-03-31T13:40:00";
+        let tz: arrow::array::timezone::Tz =
+            "America/New_York".parse().expect("Invalid timezone");
+
+        let timestamp = timestamp_str
+            .parse::<NaiveDateTime>()
+            .unwrap()
+            .and_local_timezone(tz) // this is in a local timezone
+            .unwrap()
+            .timestamp_nanos_opt()
+            .unwrap();
+
+        let expected_timestamp = timestamp_str
+            .parse::<NaiveDateTime>()
+            .unwrap()
+            .and_utc() // this is in UTC
+            .timestamp_nanos_opt()
+            .unwrap();
+
+        let res =
+            adjust_to_local_time::<TimestampNanosecondType>(timestamp, tz).unwrap();
+        assert_eq!(res, expected_timestamp);
+    }
+
+    #[test]
+    fn test_to_local_time_scalar() {
+        let timezone = Some("Europe/Brussels".into());
+        let timestamps_with_timezone = vec![
+            (
+                ScalarValue::TimestampNanosecond(
+                    Some(1_123_123_000_000_000_000),
+                    timezone.clone(),
+                ),
+                ScalarValue::TimestampNanosecond(Some(1_123_130_200_000_000_000), None),
+            ),
+            (
+                ScalarValue::TimestampMicrosecond(
+                    Some(1_123_123_000_000_000),
+                    timezone.clone(),
+                ),
+                ScalarValue::TimestampMicrosecond(Some(1_123_130_200_000_000), None),
+            ),
+            (
+                ScalarValue::TimestampMillisecond(
+                    Some(1_123_123_000_000),
+                    timezone.clone(),
+                ),
+                ScalarValue::TimestampMillisecond(Some(1_123_130_200_000), None),
+            ),
+            (
+                ScalarValue::TimestampSecond(Some(1_123_123_000), timezone),
+                ScalarValue::TimestampSecond(Some(1_123_130_200), None),
+            ),
+        ];
+
+        for (input, expected) in timestamps_with_timezone {
+            test_to_local_time_helper(input, expected);
+        }
+    }
+
+    #[test]
+    fn test_timezone_with_daylight_savings() {
+        let timezone_str = "America/New_York";
+        let tz: arrow::array::timezone::Tz =
+            timezone_str.parse().expect("Invalid timezone");
+
+        // Test data:
+        // (
+        //    the string display of the input timestamp,
+        //    the i64 representation of the timestamp before
adjustment in nanosecond,
+        //    the i64 representation of the timestamp after adjustment in nanosecond,
+        // )
+        let test_cases = vec![
+            (
+                // DST time
+                "2020-03-31T13:40:00",
+                1_585_676_400_000_000_000,
+                1_585_662_000_000_000_000,
+            ),
+            (
+                // End of DST
+                "2020-11-04T14:06:40",
+                1_604_516_800_000_000_000,
+                1_604_498_800_000_000_000,
+            ),
+        ];
+
+        for (
+            input_timestamp_str,
+            expected_input_timestamp,
+            expected_adjusted_timestamp,
+        ) in test_cases
+        {
+            let input_timestamp = input_timestamp_str
+                .parse::<NaiveDateTime>()
+                .unwrap()
+                .and_local_timezone(tz) // this is in a local timezone
+                .unwrap()
+                .timestamp_nanos_opt()
+                .unwrap();
+            assert_eq!(input_timestamp, expected_input_timestamp);
+
+            let expected_timestamp = input_timestamp_str
+                .parse::<NaiveDateTime>()
+                .unwrap()
+                .and_utc() // this is in UTC
+                .timestamp_nanos_opt()
+                .unwrap();
+            assert_eq!(expected_timestamp, expected_adjusted_timestamp);
+
+            let input = ScalarValue::TimestampNanosecond(
+                Some(input_timestamp),
+                Some(timezone_str.into()),
+            );
+            let expected =
+                ScalarValue::TimestampNanosecond(Some(expected_timestamp), None);
+            test_to_local_time_helper(input, expected)
+        }
+    }
+
+    fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) {
+        let res = ToLocalTimeFunc::new()
+            .invoke(&[ColumnarValue::Scalar(input)])
+            .unwrap();
+        match res {
+            ColumnarValue::Scalar(res) => {
+                assert_eq!(res, expected);
+            }
+            _ => panic!("unexpected return type"),
+        }
+    }
+
+    #[test]
+    fn test_to_local_time_timezones_array() {
+        let cases = [
+            (
+                vec![
+                    "2020-09-08T00:00:00",
+                    "2020-09-08T01:00:00",
+                    "2020-09-08T02:00:00",
+                    "2020-09-08T03:00:00",
+                    "2020-09-08T04:00:00",
+                ],
+                None::<Arc<str>>,
+                vec![
+                    "2020-09-08T00:00:00",
+                    "2020-09-08T01:00:00",
+                    "2020-09-08T02:00:00",
+                    "2020-09-08T03:00:00",
+                    "2020-09-08T04:00:00",
+                ],
+            ),
+            (
+                vec![
+                    "2020-09-08T00:00:00",
+                    "2020-09-08T01:00:00",
+                    "2020-09-08T02:00:00",
+                    "2020-09-08T03:00:00",
+                    "2020-09-08T04:00:00",
+                ],
+                Some("+01:00".into()),
+                vec![
+                    "2020-09-08T00:00:00",
+                    "2020-09-08T01:00:00",
+                    "2020-09-08T02:00:00",
+                    "2020-09-08T03:00:00",
+                    "2020-09-08T04:00:00",
+                ],
+            ),
+        ];
+
+        cases.iter().for_each(|(source, _tz_opt, expected)| {
+            let input = source
+                .iter()
+                .map(|s| Some(string_to_timestamp_nanos(s).unwrap()))
+                .collect::<TimestampNanosecondArray>();
+            let right = expected
+                .iter()
+                .map(|s| Some(string_to_timestamp_nanos(s).unwrap()))
+                .collect::<TimestampNanosecondArray>();
+            let result = ToLocalTimeFunc::new()
+                .invoke(&[ColumnarValue::Array(Arc::new(input))])
+                .unwrap();
+            if let ColumnarValue::Array(result) = result {
+                assert_eq!(
+                    result.data_type(),
+                    &DataType::Timestamp(TimeUnit::Nanosecond, None)
+                );
+                let left = arrow::array::cast::as_primitive_array::<
+                    TimestampNanosecondType,
+                >(&result);
+                assert_eq!(left, &right);
+            } else {
+                panic!("unexpected column type");
+            }
+        });
+    }
+}
diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt
index 2216dbfa5fd5..f4e492649b9f 100644
--- a/datafusion/sqllogictest/test_files/timestamps.slt
+++ b/datafusion/sqllogictest/test_files/timestamps.slt
@@ -2844,3 +2844,180 @@ select arrow_cast('2024-06-17T13:00:00', 'Timestamp(Nanosecond, Some("UTC"))') -
 
 query error
 select arrow_cast('2024-06-17T13:00:00', 'Timestamp(Nanosecond, Some("+00:00"))') - arrow_cast('2024-06-17T12:00:00', 'Timestamp(Microsecond, Some("+01:00"))');
+
+##########
+## Test to_local_time function
+##########
+
+# invalid number of arguments -- no argument
+statement error
+select to_local_time();
+
+# invalid number of
arguments -- more than 1 argument +statement error +select to_local_time('2024-04-01T00:00:20Z'::timestamp, 'some string'); + +# invalid argument data type +statement error DataFusion error: Execution error: The to_local_time function can only accept timestamp as the arg, got Utf8 +select to_local_time('2024-04-01T00:00:20Z'); + +# invalid timezone +statement error DataFusion error: Arrow error: Parser error: Invalid timezone "Europe/timezone": failed to parse timezone +select to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/timezone'); + +# valid query +query P +select to_local_time('2024-04-01T00:00:20Z'::timestamp); +---- +2024-04-01T00:00:20 + +query P +select to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE '+05:00'); +---- +2024-04-01T00:00:20 + +query P +select to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels'); +---- +2024-04-01T00:00:20 + +query PTPT +select + time, + arrow_typeof(time) as type, + to_local_time(time) as to_local_time, + arrow_typeof(to_local_time(time)) as to_local_time_type +from ( + select '2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels' as time +); +---- +2024-04-01T00:00:20+02:00 Timestamp(Nanosecond, Some("Europe/Brussels")) 2024-04-01T00:00:20 Timestamp(Nanosecond, None) + +# use to_local_time() in date_bin() +query P +select date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')); +---- +2024-04-01T00:00:00 + +query P +select date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AT TIME ZONE 'Europe/Brussels'; +---- +2024-04-01T00:00:00+02:00 + +# test using to_local_time() on array values +statement ok +create table t AS +VALUES + ('2024-01-01T00:00:01Z'), + ('2024-02-01T00:00:01Z'), + ('2024-03-01T00:00:01Z'), + ('2024-04-01T00:00:01Z'), + ('2024-05-01T00:00:01Z'), + ('2024-06-01T00:00:01Z'), + ('2024-07-01T00:00:01Z'), + ('2024-08-01T00:00:01Z'), + ('2024-09-01T00:00:01Z'), + ('2024-10-01T00:00:01Z'), + ('2024-11-01T00:00:01Z'), + ('2024-12-01T00:00:01Z') +; + +statement ok +create view t_utc as +select column1::timestamp AT TIME ZONE 'UTC' as "column1" +from t; + +statement ok +create view t_timezone as +select column1::timestamp AT TIME ZONE 'Europe/Brussels' as "column1" +from t; + +query PPT +select column1, to_local_time(column1::timestamp), arrow_typeof(to_local_time(column1::timestamp)) from t_utc; +---- +2024-01-01T00:00:01Z 2024-01-01T00:00:01 Timestamp(Nanosecond, None) +2024-02-01T00:00:01Z 2024-02-01T00:00:01 Timestamp(Nanosecond, None) +2024-03-01T00:00:01Z 2024-03-01T00:00:01 Timestamp(Nanosecond, None) +2024-04-01T00:00:01Z 2024-04-01T00:00:01 Timestamp(Nanosecond, None) +2024-05-01T00:00:01Z 2024-05-01T00:00:01 Timestamp(Nanosecond, None) +2024-06-01T00:00:01Z 2024-06-01T00:00:01 Timestamp(Nanosecond, None) +2024-07-01T00:00:01Z 2024-07-01T00:00:01 Timestamp(Nanosecond, None) +2024-08-01T00:00:01Z 2024-08-01T00:00:01 Timestamp(Nanosecond, None) +2024-09-01T00:00:01Z 2024-09-01T00:00:01 Timestamp(Nanosecond, None) +2024-10-01T00:00:01Z 2024-10-01T00:00:01 Timestamp(Nanosecond, None) +2024-11-01T00:00:01Z 2024-11-01T00:00:01 Timestamp(Nanosecond, None) +2024-12-01T00:00:01Z 2024-12-01T00:00:01 Timestamp(Nanosecond, None) + +query PPT +select column1, to_local_time(column1), arrow_typeof(to_local_time(column1)) from t_utc; +---- +2024-01-01T00:00:01Z 2024-01-01T00:00:01 Timestamp(Nanosecond, None) +2024-02-01T00:00:01Z 2024-02-01T00:00:01 Timestamp(Nanosecond, 
None) +2024-03-01T00:00:01Z 2024-03-01T00:00:01 Timestamp(Nanosecond, None) +2024-04-01T00:00:01Z 2024-04-01T00:00:01 Timestamp(Nanosecond, None) +2024-05-01T00:00:01Z 2024-05-01T00:00:01 Timestamp(Nanosecond, None) +2024-06-01T00:00:01Z 2024-06-01T00:00:01 Timestamp(Nanosecond, None) +2024-07-01T00:00:01Z 2024-07-01T00:00:01 Timestamp(Nanosecond, None) +2024-08-01T00:00:01Z 2024-08-01T00:00:01 Timestamp(Nanosecond, None) +2024-09-01T00:00:01Z 2024-09-01T00:00:01 Timestamp(Nanosecond, None) +2024-10-01T00:00:01Z 2024-10-01T00:00:01 Timestamp(Nanosecond, None) +2024-11-01T00:00:01Z 2024-11-01T00:00:01 Timestamp(Nanosecond, None) +2024-12-01T00:00:01Z 2024-12-01T00:00:01 Timestamp(Nanosecond, None) + +query PPT +select column1, to_local_time(column1), arrow_typeof(to_local_time(column1)) from t_timezone; +---- +2024-01-01T00:00:01+01:00 2024-01-01T00:00:01 Timestamp(Nanosecond, None) +2024-02-01T00:00:01+01:00 2024-02-01T00:00:01 Timestamp(Nanosecond, None) +2024-03-01T00:00:01+01:00 2024-03-01T00:00:01 Timestamp(Nanosecond, None) +2024-04-01T00:00:01+02:00 2024-04-01T00:00:01 Timestamp(Nanosecond, None) +2024-05-01T00:00:01+02:00 2024-05-01T00:00:01 Timestamp(Nanosecond, None) +2024-06-01T00:00:01+02:00 2024-06-01T00:00:01 Timestamp(Nanosecond, None) +2024-07-01T00:00:01+02:00 2024-07-01T00:00:01 Timestamp(Nanosecond, None) +2024-08-01T00:00:01+02:00 2024-08-01T00:00:01 Timestamp(Nanosecond, None) +2024-09-01T00:00:01+02:00 2024-09-01T00:00:01 Timestamp(Nanosecond, None) +2024-10-01T00:00:01+02:00 2024-10-01T00:00:01 Timestamp(Nanosecond, None) +2024-11-01T00:00:01+01:00 2024-11-01T00:00:01 Timestamp(Nanosecond, None) +2024-12-01T00:00:01+01:00 2024-12-01T00:00:01 Timestamp(Nanosecond, None) + +# combine to_local_time() with date_bin() +query P +select date_bin(interval '1 day', to_local_time(column1)) AT TIME ZONE 'Europe/Brussels' as date_bin from t_utc; +---- +2024-01-01T00:00:00+01:00 +2024-02-01T00:00:00+01:00 +2024-03-01T00:00:00+01:00 +2024-04-01T00:00:00+02:00 +2024-05-01T00:00:00+02:00 +2024-06-01T00:00:00+02:00 +2024-07-01T00:00:00+02:00 +2024-08-01T00:00:00+02:00 +2024-09-01T00:00:00+02:00 +2024-10-01T00:00:00+02:00 +2024-11-01T00:00:00+01:00 +2024-12-01T00:00:00+01:00 + +query P +select date_bin(interval '1 day', to_local_time(column1)) AT TIME ZONE 'Europe/Brussels' as date_bin from t_timezone; +---- +2024-01-01T00:00:00+01:00 +2024-02-01T00:00:00+01:00 +2024-03-01T00:00:00+01:00 +2024-04-01T00:00:00+02:00 +2024-05-01T00:00:00+02:00 +2024-06-01T00:00:00+02:00 +2024-07-01T00:00:00+02:00 +2024-08-01T00:00:00+02:00 +2024-09-01T00:00:00+02:00 +2024-10-01T00:00:00+02:00 +2024-11-01T00:00:00+01:00 +2024-12-01T00:00:00+01:00 + +statement ok +drop table t; + +statement ok +drop view t_utc; + +statement ok +drop view t_timezone; From 1e9f0e1d650f0549e6a8f7d6971b7373fae5199c Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen <83442793+MohamedAbdeen21@users.noreply.github.com> Date: Thu, 11 Jul 2024 19:20:10 +0300 Subject: [PATCH 09/14] Implement prettier SQL unparsing (more human readable) (#11186) * initial prettier unparse * bug fix * handling minus and divide * cleaning references and comments * moved tests * Update precedence of BETWEEN * rerun CI * Change precedence to match PGSQLs * more pretty unparser tests * Update operator precedence to match latest PGSQL * directly prettify expr_to_sql * handle IS operator * correct IS precedence * update unparser tests * update unparser example * update more unparser examples * add with_pretty builder to unparser --- .../examples/parse_sql_expr.rs 
|   9 +
 datafusion-examples/examples/plan_to_sql.rs   |  18 +-
 datafusion/expr/src/operator.rs               |  24 +-
 datafusion/sql/src/unparser/expr.rs           | 230 ++++++++++++++----
 datafusion/sql/src/unparser/mod.rs            |  15 +-
 datafusion/sql/tests/cases/plan_to_sql.rs     |  99 +++++++-
 6 files changed, 319 insertions(+), 76 deletions(-)

diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/parse_sql_expr.rs
index a1fc5d269a04..e23e5accae39 100644
--- a/datafusion-examples/examples/parse_sql_expr.rs
+++ b/datafusion-examples/examples/parse_sql_expr.rs
@@ -153,5 +153,14 @@ async fn round_trip_parse_sql_expr_demo() -> Result<()> {
 
     assert_eq!(sql, round_trip_sql);
 
+    // enable pretty-unparsing. This makes the output more human-readable
+    // but can be problematic when passed to other SQL engines due to
+    // differences in precedence rules between DataFusion and target engines.
+    let unparser = Unparser::default().with_pretty(true);
+
+    let pretty = "int_col < 5 OR double_col = 8";
+    let pretty_round_trip_sql = unparser.expr_to_sql(&parsed_expr)?.to_string();
+    assert_eq!(pretty, pretty_round_trip_sql);
+
     Ok(())
 }
diff --git a/datafusion-examples/examples/plan_to_sql.rs b/datafusion-examples/examples/plan_to_sql.rs
index bd708fe52bc1..f719a33fb624 100644
--- a/datafusion-examples/examples/plan_to_sql.rs
+++ b/datafusion-examples/examples/plan_to_sql.rs
@@ -31,9 +31,9 @@ use datafusion_sql::unparser::{plan_to_sql, Unparser};
 /// 1. [`simple_expr_to_sql_demo`]: Create a simple expression [`Exprs`] with
 /// fluent API and convert to sql suitable for passing to another database
 ///
-/// 2. [`simple_expr_to_sql_demo_no_escape`] Create a simple expression
-/// [`Exprs`] with fluent API and convert to sql without escaping column names
-/// more suitable for displaying to humans.
+/// 2. [`simple_expr_to_pretty_sql_demo`] Create a simple expression
+/// [`Exprs`] with fluent API and convert to sql without extra parentheses,
+/// suitable for displaying to humans
 ///
 /// 3. [`simple_expr_to_sql_demo_escape_mysql_style`]: Create a simple
 /// expression [`Exprs`] with fluent API and convert to sql escaping column
@@ -49,6 +49,7 @@ use datafusion_sql::unparser::{plan_to_sql, Unparser};
 async fn main() -> Result<()> {
     // See how to evaluate expressions
     simple_expr_to_sql_demo()?;
+    simple_expr_to_pretty_sql_demo()?;
     simple_expr_to_sql_demo_escape_mysql_style()?;
     simple_plan_to_sql_demo().await?;
     round_trip_plan_to_sql_demo().await?;
@@ -64,6 +65,17 @@ fn simple_expr_to_sql_demo() -> Result<()> {
     Ok(())
 }
 
+/// DataFusion can remove parentheses when converting an expression to SQL.
+/// Note that the output is intended for humans, not for other SQL engines,
+/// as differences in precedence rules can cause expressions to be parsed differently.
+fn simple_expr_to_pretty_sql_demo() -> Result<()> {
+    let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8)));
+    let unparser = Unparser::default().with_pretty(true);
+    let sql = unparser.expr_to_sql(&expr)?.to_string();
+    assert_eq!(sql, r#"a < 5 OR a = 8"#);
+    Ok(())
+}
+
 /// DataFusion can convert expressions to SQL without escaping column names
 /// using a custom dialect and an explicit unparser
 fn simple_expr_to_sql_demo_escape_mysql_style() -> Result<()> {
diff --git a/datafusion/expr/src/operator.rs b/datafusion/expr/src/operator.rs
index a10312e23446..9bb8c48d6c71 100644
--- a/datafusion/expr/src/operator.rs
+++ b/datafusion/expr/src/operator.rs
@@ -218,29 +218,23 @@ impl Operator {
     }
 
     /// Get the operator precedence
-    /// use the PostgreSQL operator precedence documentation as a reference
+    /// use the PostgreSQL operator precedence documentation as a reference
     pub fn precedence(&self) -> u8 {
         match self {
             Operator::Or => 5,
             Operator::And => 10,
-            Operator::NotEq
-            | Operator::Eq
-            | Operator::Lt
-            | Operator::LtEq
-            | Operator::Gt
-            | Operator::GtEq => 20,
-            Operator::Plus | Operator::Minus => 30,
-            Operator::Multiply | Operator::Divide | Operator::Modulo => 40,
+            Operator::Eq | Operator::NotEq | Operator::LtEq | Operator::GtEq => 15,
+            Operator::Lt | Operator::Gt => 20,
+            Operator::LikeMatch
+            | Operator::NotLikeMatch
+            | Operator::ILikeMatch
+            | Operator::NotILikeMatch => 25,
             Operator::IsDistinctFrom
             | Operator::IsNotDistinctFrom
             | Operator::RegexMatch
             | Operator::RegexNotMatch
             | Operator::RegexIMatch
             | Operator::RegexNotIMatch
-            | Operator::LikeMatch
-            | Operator::ILikeMatch
-            | Operator::NotLikeMatch
-            | Operator::NotILikeMatch
             | Operator::BitwiseAnd
             | Operator::BitwiseOr
             | Operator::BitwiseShiftLeft
             | Operator::BitwiseShiftRight
             | Operator::BitwiseXor
             | Operator::StringConcat
             | Operator::AtArrow
-            | Operator::ArrowAt => 0,
+            | Operator::ArrowAt => 30,
+            Operator::Plus | Operator::Minus => 40,
+            Operator::Multiply | Operator::Divide | Operator::Modulo => 45,
         }
     }
 }
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index 198186934c84..e0d05c400cb0 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -30,8 +30,8 @@ use arrow_array::{Date32Array, Date64Array, PrimitiveArray};
 use arrow_schema::DataType;
 use sqlparser::ast::Value::SingleQuotedString;
 use sqlparser::ast::{
-    self, Expr as AstExpr, Function, FunctionArg, Ident, Interval, TimezoneInfo,
-    UnaryOperator,
+    self, BinaryOperator, Expr as AstExpr, Function, FunctionArg, Ident, Interval,
+    TimezoneInfo, UnaryOperator,
 };
 
 use datafusion_common::{
@@ -101,8 +101,21 @@ pub fn expr_to_unparsed(expr: &Expr) -> Result<Unparsed> {
     unparser.expr_to_unparsed(expr)
 }
 
+const LOWEST: &BinaryOperator = &BinaryOperator::Or;
+// closest precedence we have to IS operator is BitwiseAnd (any other) in PG docs
+// (https://www.postgresql.org/docs/7.2/sql-precedence.html)
+const IS: &BinaryOperator = &BinaryOperator::BitwiseAnd;
+
 impl Unparser<'_> {
     pub fn expr_to_sql(&self, expr: &Expr) -> Result<ast::Expr> {
+        let mut root_expr = self.expr_to_sql_inner(expr)?;
+        if self.pretty {
+            root_expr = self.remove_unnecessary_nesting(root_expr, LOWEST, LOWEST);
+        }
+        Ok(root_expr)
+    }
+
+    fn expr_to_sql_inner(&self, expr: &Expr) -> Result<ast::Expr> {
         match expr {
             Expr::InList(InList {
                 expr,
@@ -111,10 +124,10 @@ impl Unparser<'_> {
             }) => {
                 let list_expr = list
                     .iter()
-                    .map(|e| self.expr_to_sql(e))
+                    .map(|e| self.expr_to_sql_inner(e))
                     .collect::<Result<Vec<_>>>()?;
                 Ok(ast::Expr::InList {
-                    expr: Box::new(self.expr_to_sql(expr)?),
+                    expr: Box::new(self.expr_to_sql_inner(expr)?),
                     list: list_expr,
                    negated: *negated,
                 })
             }
@@ -128,7 +141,7 @@ impl Unparser<'_> {
                         if matches!(e, Expr::Wildcard { qualifier: None }) {
                             Ok(FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard))
                         } else {
-                            self.expr_to_sql(e).map(|e| {
+                            self.expr_to_sql_inner(e).map(|e| {
                                 FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e))
                             })
                         }
@@ -157,9 +170,9 @@ impl Unparser<'_> {
                 low,
                 high,
             }) => {
-                let sql_parser_expr = self.expr_to_sql(expr)?;
-                let sql_low = self.expr_to_sql(low)?;
-                let sql_high = self.expr_to_sql(high)?;
+                let sql_parser_expr = self.expr_to_sql_inner(expr)?;
+                let sql_low = self.expr_to_sql_inner(low)?;
+                let sql_high = self.expr_to_sql_inner(high)?;
                 Ok(ast::Expr::Nested(Box::new(self.between_op_to_sql(
                     sql_parser_expr,
                     *negated,
@@ -169,8 +182,8 @@ impl Unparser<'_> {
             }
             Expr::Column(col) => self.col_to_sql(col),
             Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
-                let l = self.expr_to_sql(left.as_ref())?;
-                let r = self.expr_to_sql(right.as_ref())?;
+                let l = self.expr_to_sql_inner(left.as_ref())?;
+                let r = self.expr_to_sql_inner(right.as_ref())?;
                 let op = self.op_to_sql(op)?;
 
                 Ok(ast::Expr::Nested(Box::new(self.binary_op_to_sql(l, r, op))))
@@ -182,21 +195,21 @@ impl Unparser<'_> {
             }) => {
                 let conditions = when_then_expr
                     .iter()
-                    .map(|(w, _)| self.expr_to_sql(w))
+                    .map(|(w, _)| self.expr_to_sql_inner(w))
                     .collect::<Result<Vec<_>>>()?;
                 let results = when_then_expr
                     .iter()
-                    .map(|(_, t)| self.expr_to_sql(t))
+                    .map(|(_, t)| self.expr_to_sql_inner(t))
                     .collect::<Result<Vec<_>>>()?;
                 let operand = match expr.as_ref() {
-                    Some(e) => match self.expr_to_sql(e) {
+                    Some(e) => match self.expr_to_sql_inner(e) {
                         Ok(sql_expr) => Some(Box::new(sql_expr)),
                         Err(_) => None,
                     },
                     None => None,
                 };
                 let else_result = match else_expr.as_ref() {
-                    Some(e) => match self.expr_to_sql(e) {
+                    Some(e) => match self.expr_to_sql_inner(e) {
                         Ok(sql_expr) => Some(Box::new(sql_expr)),
                         Err(_) => None,
                     },
@@ -211,7 +224,7 @@ impl Unparser<'_> {
                 })
             }
             Expr::Cast(Cast { expr, data_type }) => {
-                let inner_expr = self.expr_to_sql(expr)?;
+                let inner_expr = self.expr_to_sql_inner(expr)?;
                 Ok(ast::Expr::Cast {
                     kind: ast::CastKind::Cast,
                     expr: Box::new(inner_expr),
@@ -220,7 +233,7 @@ impl Unparser<'_> {
                 })
             }
             Expr::Literal(value) => Ok(self.scalar_to_sql(value)?),
-            Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql(expr),
+            Expr::Alias(Alias { expr, name: _, ..
}) => self.expr_to_sql_inner(expr),
             Expr::WindowFunction(WindowFunction {
                 fun,
                 args,
@@ -255,7 +268,7 @@ impl Unparser<'_> {
                     window_name: None,
                     partition_by: partition_by
                         .iter()
-                        .map(|e| self.expr_to_sql(e))
+                        .map(|e| self.expr_to_sql_inner(e))
                         .collect::<Result<Vec<_>>>()?,
                     order_by,
                     window_frame: Some(ast::WindowFrame {
@@ -296,8 +309,8 @@ impl Unparser<'_> {
                 case_insensitive: _,
             }) => Ok(ast::Expr::Like {
                 negated: *negated,
-                expr: Box::new(self.expr_to_sql(expr)?),
-                pattern: Box::new(self.expr_to_sql(pattern)?),
+                expr: Box::new(self.expr_to_sql_inner(expr)?),
+                pattern: Box::new(self.expr_to_sql_inner(pattern)?),
                 escape_char: escape_char.map(|c| c.to_string()),
             }),
             Expr::AggregateFunction(agg) => {
@@ -305,7 +318,7 @@ impl Unparser<'_> {
                 let args = self.function_args_to_sql(&agg.args)?;
                 let filter = match &agg.filter {
-                    Some(filter) => Some(Box::new(self.expr_to_sql(filter)?)),
+                    Some(filter) => Some(Box::new(self.expr_to_sql_inner(filter)?)),
                     None => None,
                 };
                 Ok(ast::Expr::Function(Function {
@@ -339,7 +352,7 @@ impl Unparser<'_> {
                 Ok(ast::Expr::Subquery(sub_query))
             }
             Expr::InSubquery(insubq) => {
-                let inexpr = Box::new(self.expr_to_sql(insubq.expr.as_ref())?);
+                let inexpr = Box::new(self.expr_to_sql_inner(insubq.expr.as_ref())?);
                 let sub_statement =
                     self.plan_to_sql(insubq.subquery.subquery.as_ref())?;
                 let sub_query = if let ast::Statement::Query(inner_query) = sub_statement
@@ -377,38 +390,38 @@ impl Unparser<'_> {
                 nulls_first: _,
             }) => plan_err!("Sort expression should be handled by expr_to_unparsed"),
             Expr::IsNull(expr) => {
-                Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql(expr)?)))
-            }
-            Expr::IsNotNull(expr) => {
-                Ok(ast::Expr::IsNotNull(Box::new(self.expr_to_sql(expr)?)))
+                Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql_inner(expr)?)))
             }
+            Expr::IsNotNull(expr) => Ok(ast::Expr::IsNotNull(Box::new(
+                self.expr_to_sql_inner(expr)?,
+            ))),
             Expr::IsTrue(expr) => {
-                Ok(ast::Expr::IsTrue(Box::new(self.expr_to_sql(expr)?)))
-            }
-            Expr::IsNotTrue(expr) => {
-                Ok(ast::Expr::IsNotTrue(Box::new(self.expr_to_sql(expr)?)))
+                Ok(ast::Expr::IsTrue(Box::new(self.expr_to_sql_inner(expr)?)))
             }
+            Expr::IsNotTrue(expr) => Ok(ast::Expr::IsNotTrue(Box::new(
+                self.expr_to_sql_inner(expr)?,
+            ))),
             Expr::IsFalse(expr) => {
-                Ok(ast::Expr::IsFalse(Box::new(self.expr_to_sql(expr)?)))
-            }
-            Expr::IsNotFalse(expr) => {
-                Ok(ast::Expr::IsNotFalse(Box::new(self.expr_to_sql(expr)?)))
-            }
-            Expr::IsUnknown(expr) => {
-                Ok(ast::Expr::IsUnknown(Box::new(self.expr_to_sql(expr)?)))
-            }
-            Expr::IsNotUnknown(expr) => {
-                Ok(ast::Expr::IsNotUnknown(Box::new(self.expr_to_sql(expr)?)))
-            }
+                Ok(ast::Expr::IsFalse(Box::new(self.expr_to_sql_inner(expr)?)))
+            }
+            Expr::IsNotFalse(expr) => Ok(ast::Expr::IsNotFalse(Box::new(
+                self.expr_to_sql_inner(expr)?,
+            ))),
+            Expr::IsUnknown(expr) => Ok(ast::Expr::IsUnknown(Box::new(
+                self.expr_to_sql_inner(expr)?,
+            ))),
+            Expr::IsNotUnknown(expr) => Ok(ast::Expr::IsNotUnknown(Box::new(
+                self.expr_to_sql_inner(expr)?,
+            ))),
             Expr::Not(expr) => {
-                let sql_parser_expr = self.expr_to_sql(expr)?;
+                let sql_parser_expr = self.expr_to_sql_inner(expr)?;
                 Ok(AstExpr::UnaryOp {
                     op: UnaryOperator::Not,
                     expr: Box::new(sql_parser_expr),
                 })
             }
             Expr::Negative(expr) => {
-                let sql_parser_expr = self.expr_to_sql(expr)?;
+                let sql_parser_expr = self.expr_to_sql_inner(expr)?;
                 Ok(AstExpr::UnaryOp {
                     op: UnaryOperator::Minus,
                     expr: Box::new(sql_parser_expr),
@@ -432,7 +445,7 @@ impl Unparser<'_> {
                 })
             }
             Expr::TryCast(TryCast { expr, data_type }) => {
-                let inner_expr = self.expr_to_sql(expr)?;
+                let
inner_expr = self.expr_to_sql_inner(expr)?;
                 Ok(ast::Expr::Cast {
                     kind: ast::CastKind::TryCast,
                     expr: Box::new(inner_expr),
@@ -449,7 +462,7 @@ impl Unparser<'_> {
                     .iter()
                     .map(|set| {
                         set.iter()
-                            .map(|e| self.expr_to_sql(e))
+                            .map(|e| self.expr_to_sql_inner(e))
                             .collect::<Result<Vec<_>>>()
                     })
                     .collect::<Result<Vec<_>>>()?;
@@ -460,7 +473,7 @@ impl Unparser<'_> {
                 let expr_ast_sets = cube
                     .iter()
                     .map(|e| {
-                        let sql = self.expr_to_sql(e)?;
+                        let sql = self.expr_to_sql_inner(e)?;
                         Ok(vec![sql])
                     })
                     .collect::<Result<Vec<_>>>()?;
@@ -470,7 +483,7 @@ impl Unparser<'_> {
                 let expr_ast_sets: Vec<Vec<ast::Expr>> = rollup
                     .iter()
                     .map(|e| {
-                        let sql = self.expr_to_sql(e)?;
+                        let sql = self.expr_to_sql_inner(e)?;
                         Ok(vec![sql])
                     })
                     .collect::<Result<Vec<_>>>()?;
@@ -603,6 +616,88 @@ impl Unparser<'_> {
         }
     }
 
+    /// Given an expression of the form `((a + b) * (c * d))`,
+    /// the parenthesizing is redundant if the precedence of the nested expression is already higher
+    /// than the surrounding operators' precedence. The above expression would become
+    /// `(a + b) * c * d`.
+    ///
+    /// Also note that when fetching the precedence of a nested expression, we ignore other nested
+    /// expressions, so precedence of expr `(a * (b + c))` equals `*` and not `+`.
+    fn remove_unnecessary_nesting(
+        &self,
+        expr: ast::Expr,
+        left_op: &BinaryOperator,
+        right_op: &BinaryOperator,
+    ) -> ast::Expr {
+        match expr {
+            ast::Expr::Nested(nested) => {
+                let surrounding_precedence = self
+                    .sql_op_precedence(left_op)
+                    .max(self.sql_op_precedence(right_op));
+
+                let inner_precedence = self.inner_precedence(&nested);
+
+                let not_associative =
+                    matches!(left_op, BinaryOperator::Minus | BinaryOperator::Divide);
+
+                if inner_precedence == surrounding_precedence && not_associative {
+                    ast::Expr::Nested(Box::new(
+                        self.remove_unnecessary_nesting(*nested, LOWEST, LOWEST),
+                    ))
+                } else if inner_precedence >= surrounding_precedence {
+                    self.remove_unnecessary_nesting(*nested, left_op, right_op)
+                } else {
+                    ast::Expr::Nested(Box::new(
+                        self.remove_unnecessary_nesting(*nested, LOWEST, LOWEST),
+                    ))
+                }
+            }
+            ast::Expr::BinaryOp { left, op, right } => ast::Expr::BinaryOp {
+                left: Box::new(self.remove_unnecessary_nesting(*left, left_op, &op)),
+                right: Box::new(self.remove_unnecessary_nesting(*right, &op, right_op)),
+                op,
+            },
+            ast::Expr::IsTrue(expr) => ast::Expr::IsTrue(Box::new(
+                self.remove_unnecessary_nesting(*expr, left_op, IS),
+            )),
+            ast::Expr::IsNotTrue(expr) => ast::Expr::IsNotTrue(Box::new(
+                self.remove_unnecessary_nesting(*expr, left_op, IS),
+            )),
+            ast::Expr::IsFalse(expr) => ast::Expr::IsFalse(Box::new(
+                self.remove_unnecessary_nesting(*expr, left_op, IS),
+            )),
+            ast::Expr::IsNotFalse(expr) => ast::Expr::IsNotFalse(Box::new(
+                self.remove_unnecessary_nesting(*expr, left_op, IS),
+            )),
+            ast::Expr::IsNull(expr) => ast::Expr::IsNull(Box::new(
+                self.remove_unnecessary_nesting(*expr, left_op, IS),
+            )),
+            ast::Expr::IsNotNull(expr) => ast::Expr::IsNotNull(Box::new(
+                self.remove_unnecessary_nesting(*expr, left_op, IS),
+            )),
+            ast::Expr::IsUnknown(expr) => ast::Expr::IsUnknown(Box::new(
+                self.remove_unnecessary_nesting(*expr, left_op, IS),
+            )),
+            ast::Expr::IsNotUnknown(expr) => ast::Expr::IsNotUnknown(Box::new(
+                self.remove_unnecessary_nesting(*expr, left_op, IS),
+            )),
+            _ => expr,
+        }
+    }
+
+    fn inner_precedence(&self, expr: &ast::Expr) -> u8 {
+        match expr {
+            ast::Expr::Nested(_) | ast::Expr::Identifier(_) | ast::Expr::Value(_) => 100,
+            ast::Expr::BinaryOp { op, ..
} => self.sql_op_precedence(op),
+            // closest precedence we currently have to Between is PGLikeMatch
+            // (https://www.postgresql.org/docs/7.2/sql-precedence.html)
+            ast::Expr::Between { .. } => {
+                self.sql_op_precedence(&ast::BinaryOperator::PGLikeMatch)
+            }
+            _ => 0,
+        }
+    }
+
     pub(super) fn between_op_to_sql(
         &self,
         expr: ast::Expr,
@@ -618,6 +713,48 @@ impl Unparser<'_> {
         }
     }
 
+    fn sql_op_precedence(&self, op: &BinaryOperator) -> u8 {
+        match self.sql_to_op(op) {
+            Ok(op) => op.precedence(),
+            Err(_) => 0,
+        }
+    }
+
+    fn sql_to_op(&self, op: &BinaryOperator) -> Result<Operator> {
+        match op {
+            ast::BinaryOperator::Eq => Ok(Operator::Eq),
+            ast::BinaryOperator::NotEq => Ok(Operator::NotEq),
+            ast::BinaryOperator::Lt => Ok(Operator::Lt),
+            ast::BinaryOperator::LtEq => Ok(Operator::LtEq),
+            ast::BinaryOperator::Gt => Ok(Operator::Gt),
+            ast::BinaryOperator::GtEq => Ok(Operator::GtEq),
+            ast::BinaryOperator::Plus => Ok(Operator::Plus),
+            ast::BinaryOperator::Minus => Ok(Operator::Minus),
+            ast::BinaryOperator::Multiply => Ok(Operator::Multiply),
+            ast::BinaryOperator::Divide => Ok(Operator::Divide),
+            ast::BinaryOperator::Modulo => Ok(Operator::Modulo),
+            ast::BinaryOperator::And => Ok(Operator::And),
+            ast::BinaryOperator::Or => Ok(Operator::Or),
+            ast::BinaryOperator::PGRegexMatch => Ok(Operator::RegexMatch),
+            ast::BinaryOperator::PGRegexIMatch => Ok(Operator::RegexIMatch),
+            ast::BinaryOperator::PGRegexNotMatch => Ok(Operator::RegexNotMatch),
+            ast::BinaryOperator::PGRegexNotIMatch => Ok(Operator::RegexNotIMatch),
+            ast::BinaryOperator::PGILikeMatch => Ok(Operator::ILikeMatch),
+            ast::BinaryOperator::PGNotLikeMatch => Ok(Operator::NotLikeMatch),
+            ast::BinaryOperator::PGLikeMatch => Ok(Operator::LikeMatch),
+            ast::BinaryOperator::PGNotILikeMatch => Ok(Operator::NotILikeMatch),
+            ast::BinaryOperator::BitwiseAnd => Ok(Operator::BitwiseAnd),
+            ast::BinaryOperator::BitwiseOr => Ok(Operator::BitwiseOr),
+            ast::BinaryOperator::BitwiseXor => Ok(Operator::BitwiseXor),
+            ast::BinaryOperator::PGBitwiseShiftRight => Ok(Operator::BitwiseShiftRight),
+            ast::BinaryOperator::PGBitwiseShiftLeft => Ok(Operator::BitwiseShiftLeft),
+            ast::BinaryOperator::StringConcat => Ok(Operator::StringConcat),
+            ast::BinaryOperator::AtArrow => Ok(Operator::AtArrow),
+            ast::BinaryOperator::ArrowAt => Ok(Operator::ArrowAt),
+            _ => not_impl_err!("unsupported operation: {op:?}"),
+        }
+    }
+
     fn op_to_sql(&self, op: &Operator) -> Result<ast::BinaryOperator> {
         match op {
             Operator::Eq => Ok(ast::BinaryOperator::Eq),
@@ -1538,6 +1675,7 @@ mod tests {
 
         Ok(())
     }
+
     #[test]
     fn custom_dialect() -> Result<()> {
         let dialect = CustomDialect::new(Some('\''));
diff --git a/datafusion/sql/src/unparser/mod.rs b/datafusion/sql/src/unparser/mod.rs
index fbbed4972b17..e5ffbc8a212a 100644
--- a/datafusion/sql/src/unparser/mod.rs
+++ b/datafusion/sql/src/unparser/mod.rs
@@ -29,11 +29,23 @@ pub mod dialect;
 
 pub struct Unparser<'a> {
     dialect: &'a dyn Dialect,
+    pretty: bool,
 }
 
 impl<'a> Unparser<'a> {
     pub fn new(dialect: &'a dyn Dialect) -> Self {
-        Self { dialect }
+        Self {
+            dialect,
+            pretty: false,
+        }
+    }
+
+    /// Allow the unparser to remove parentheses according to the precedence rules of DataFusion.
+    /// This might make the SQL invalid for other SQL query engines with different precedence
+    /// rules, even if it's valid for DataFusion.
+ pub fn with_pretty(mut self, pretty: bool) -> Self { + self.pretty = pretty; + self } } @@ -41,6 +53,7 @@ impl<'a> Default for Unparser<'a> { fn default() -> Self { Self { dialect: &DefaultDialect {}, + pretty: false, } } } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 374403d853f9..91295b2e8aae 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -104,26 +104,26 @@ fn roundtrip_statement() -> Result<()> { "select id, count(*) as cnt from (select p1.id as id from person p1 inner join person p2 on p1.id=p2.id) group by id", "select id, count(*), first_name from person group by first_name, id", "select id, sum(age), first_name from person group by first_name, id", - "select id, count(*), first_name - from person + "select id, count(*), first_name + from person where id!=3 and first_name=='test' - group by first_name, id + group by first_name, id having count(*)>5 and count(*)<10 order by count(*)", - r#"select id, count("First Name") as count_first_name, "Last Name" + r#"select id, count("First Name") as count_first_name, "Last Name" from person_quoted_cols where id!=3 and "First Name"=='test' - group by "Last Name", id + group by "Last Name", id having count_first_name>5 and count_first_name<10 order by count_first_name, "Last Name""#, r#"select p.id, count("First Name") as count_first_name, - "Last Name", sum(qp.id/p.id - (select sum(id) from person_quoted_cols) ) / (select count(*) from person) + "Last Name", sum(qp.id/p.id - (select sum(id) from person_quoted_cols) ) / (select count(*) from person) from (select id, "First Name", "Last Name" from person_quoted_cols) qp inner join (select * from person) p on p.id = qp.id - where p.id!=3 and "First Name"=='test' and qp.id in + where p.id!=3 and "First Name"=='test' and qp.id in (select id from (select id, count(*) from person group by id having count(*) > 0)) - group by "Last Name", p.id + group by "Last Name", p.id having count_first_name>5 and count_first_name<10 order by count_first_name, "Last Name""#, r#"SELECT j1_string as string FROM j1 @@ -134,12 +134,12 @@ fn roundtrip_statement() -> Result<()> { SELECT j2_string as string FROM j2 ORDER BY string DESC LIMIT 10"#, - "SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), - last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + "SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), first_name from person", - r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), sum(id) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) from person"#, - "SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person", + "SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person", ]; // For each test sql string, we transform as follows: @@ -314,3 +314,78 @@ fn test_table_references_in_plan_to_sql() { "SELECT \"table\".id, \"table\".\"value\" FROM \"table\"", ); } + +#[test] +fn test_pretty_roundtrip() -> Result<()> { + let schema = Schema::new(vec![ + 
Field::new("id", DataType::Utf8, false), + Field::new("age", DataType::Utf8, false), + ]); + + let df_schema = DFSchema::try_from(schema)?; + + let context = MockContextProvider::default(); + let sql_to_rel = SqlToRel::new(&context); + + let unparser = Unparser::default().with_pretty(true); + + let sql_to_pretty_unparse = vec![ + ("((id < 5) OR (age = 8))", "id < 5 OR age = 8"), + ("((id + 5) * (age * 8))", "(id + 5) * age * 8"), + ("(3 + (5 * 6) * 3)", "3 + 5 * 6 * 3"), + ("((3 * (5 + 6)) * 3)", "3 * (5 + 6) * 3"), + ("((3 AND (5 OR 6)) * 3)", "(3 AND (5 OR 6)) * 3"), + ("((3 + (5 + 6)) * 3)", "(3 + 5 + 6) * 3"), + ("((3 + (5 + 6)) + 3)", "3 + 5 + 6 + 3"), + ("3 + 5 + 6 + 3", "3 + 5 + 6 + 3"), + ("3 + (5 + (6 + 3))", "3 + 5 + 6 + 3"), + ("3 + ((5 + 6) + 3)", "3 + 5 + 6 + 3"), + ("(3 + 5) + (6 + 3)", "3 + 5 + 6 + 3"), + ("((3 + 5) + (6 + 3))", "3 + 5 + 6 + 3"), + ( + "((id > 10) OR (age BETWEEN 10 AND 20))", + "id > 10 OR age BETWEEN 10 AND 20", + ), + ( + "((id > 10) * (age BETWEEN 10 AND 20))", + "(id > 10) * (age BETWEEN 10 AND 20)", + ), + ("id - (age - 8)", "id - (age - 8)"), + ("((id - age) - 8)", "id - age - 8"), + ("(id OR (age - 8))", "id OR age - 8"), + ("(id / (age - 8))", "id / (age - 8)"), + ("((id / age) * 8)", "id / age * 8"), + ("((age + 10) < 20) IS TRUE", "(age + 10 < 20) IS TRUE"), + ( + "(20 > (age + 5)) IS NOT FALSE", + "(20 > age + 5) IS NOT FALSE", + ), + ("(true AND false) IS FALSE", "(true AND false) IS FALSE"), + ("true AND (false IS FALSE)", "true AND false IS FALSE"), + ]; + + for (sql, pretty) in sql_to_pretty_unparse.iter() { + let sql_expr = Parser::new(&GenericDialect {}) + .try_with_sql(sql)? + .parse_expr()?; + let expr = + sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())?; + let round_trip_sql = unparser.expr_to_sql(&expr)?.to_string(); + assert_eq!(pretty.to_string(), round_trip_sql); + + // verify that the pretty string parses to the same underlying Expr + let pretty_sql_expr = Parser::new(&GenericDialect {}) + .try_with_sql(pretty)? 
+ .parse_expr()?; + + let pretty_expr = sql_to_rel.sql_to_expr( + pretty_sql_expr, + &df_schema, + &mut PlannerContext::new(), + )?; + + assert_eq!(expr.to_string(), pretty_expr.to_string()); + } + + Ok(()) +} From e19dd2d0b91f30b97fd68da894137987c1318b18 Mon Sep 17 00:00:00 2001 From: Chunchun Ye <14298407+appletreeisyellow@users.noreply.github.com> Date: Thu, 11 Jul 2024 11:21:51 -0500 Subject: [PATCH 10/14] Add `to_local_time()` in function reference docs (#11401) * chore: add document for `to_local_time()` * chore: feedback Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb --- .../source/user-guide/sql/scalar_functions.md | 65 ++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index d636726b45fe..d2e012cf4093 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1480,6 +1480,7 @@ contains(string, search_string) - [make_date](#make_date) - [to_char](#to_char) - [to_date](#to_date) +- [to_local_time](#to_local_time) - [to_timestamp](#to_timestamp) - [to_timestamp_millis](#to_timestamp_millis) - [to_timestamp_micros](#to_timestamp_micros) @@ -1710,7 +1711,7 @@ to_char(expression, format) #### Example ``` -> > select to_char('2023-03-01'::date, '%d-%m-%Y'); +> select to_char('2023-03-01'::date, '%d-%m-%Y'); +----------------------------------------------+ | to_char(Utf8("2023-03-01"),Utf8("%d-%m-%Y")) | +----------------------------------------------+ @@ -1771,6 +1772,68 @@ to_date(expression[, ..., format_n]) Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_date.rs) +### `to_local_time` + +Converts a timestamp with a timezone to a timestamp without a timezone (with no offset or +timezone information). This function handles daylight saving time changes. + +``` +to_local_time(expression) +``` + +#### Arguments + +- **expression**: Time expression to operate on. Can be a constant, column, or function. 
+ +#### Example + +``` +> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp); ++---------------------------------------------+ +| to_local_time(Utf8("2024-04-01T00:00:20Z")) | ++---------------------------------------------+ +| 2024-04-01T00:00:20 | ++---------------------------------------------+ + +> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels'); ++---------------------------------------------+ +| to_local_time(Utf8("2024-04-01T00:00:20Z")) | ++---------------------------------------------+ +| 2024-04-01T00:00:20 | ++---------------------------------------------+ + +> SELECT + time, + arrow_typeof(time) as type, + to_local_time(time) as to_local_time, + arrow_typeof(to_local_time(time)) as to_local_time_type +FROM ( + SELECT '2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels' AS time +); ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ +| time | type | to_local_time | to_local_time_type | ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ +| 2024-04-01T00:00:20+02:00 | Timestamp(Nanosecond, Some("Europe/Brussels")) | 2024-04-01T00:00:20 | Timestamp(Nanosecond, None) | ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ + +# combine `to_local_time()` with `date_bin()` to bin on boundaries in the timezone rather +# than UTC boundaries + +> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AS date_bin; ++---------------------+ +| date_bin | ++---------------------+ +| 2024-04-01T00:00:00 | ++---------------------+ + +> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AT TIME ZONE 'Europe/Brussels' AS date_bin_with_timezone; ++---------------------------+ +| date_bin_with_timezone | ++---------------------------+ +| 2024-04-01T00:00:00+02:00 | ++---------------------------+ +``` + ### `to_timestamp` Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). 
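To make the new documentation concrete, here is a minimal end-to-end sketch of the behavior documented above (an assumption: it requires a DataFusion build that already includes this patch series, so that `to_local_time` is registered):

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Strip the Brussels offset while keeping the wall-clock value.
    // Per the docs above, this prints 2024-04-01T00:00:20 with
    // result type Timestamp(Nanosecond, None).
    ctx.sql(
        "SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp \
         AT TIME ZONE 'Europe/Brussels')",
    )
    .await?
    .show()
    .await?;
    Ok(())
}
```
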
From 4402a1a9dd8ebec1640b2fa807781a2701407672 Mon Sep 17 00:00:00 2001
From: Dharan Aditya
Date: Thu, 11 Jul 2024 21:52:06 +0530
Subject: [PATCH 11/14] Move `overlay` planning to `ExprPlanner` (#11398)

* move overlay to expr planner

* typo
---
 datafusion/expr/src/planner.rs           |  7 ++++++
 datafusion/functions/src/core/planner.rs |  6 +++++
 datafusion/functions/src/string/mod.rs   |  1 -
 datafusion/sql/src/expr/mod.rs           | 28 ++++++++++++------------
 4 files changed, 27 insertions(+), 15 deletions(-)

diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs
index aeb8ed8372b7..2f13923b1f10 100644
--- a/datafusion/expr/src/planner.rs
+++ b/datafusion/expr/src/planner.rs
@@ -161,6 +161,13 @@ pub trait ExprPlanner: Send + Sync {
     ) -> Result<PlannerResult<Vec<Expr>>> {
         Ok(PlannerResult::Original(args))
     }
+
+    /// Plans an overlay expression eg `overlay(str PLACING substr FROM pos [FOR count])`
+    ///
+    /// Returns the original expression arguments if not possible
+    fn plan_overlay(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
+        Ok(PlannerResult::Original(args))
+    }
 }
 
 /// An operator with two arguments to plan
diff --git a/datafusion/functions/src/core/planner.rs b/datafusion/functions/src/core/planner.rs
index 748b598d292f..63eaa9874c2b 100644
--- a/datafusion/functions/src/core/planner.rs
+++ b/datafusion/functions/src/core/planner.rs
@@ -56,4 +56,10 @@ impl ExprPlanner for CoreFunctionPlanner {
             ),
         )))
     }
+
+    fn plan_overlay(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
+        Ok(PlannerResult::Planned(Expr::ScalarFunction(
+            ScalarFunction::new_udf(crate::string::overlay(), args),
+        )))
+    }
 }
diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs
index 5bf372c29f2d..9a19151a85e2 100644
--- a/datafusion/functions/src/string/mod.rs
+++ b/datafusion/functions/src/string/mod.rs
@@ -182,7 +182,6 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
         lower(),
         ltrim(),
         octet_length(),
-        overlay(),
         repeat(),
         replace(),
         rtrim(),
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index 859842e212be..062ef805fd9f 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -193,7 +193,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     }
                 }
 
-                not_impl_err!("Extract not supported by UserDefinedExtensionPlanners: {extract_args:?}")
+                not_impl_err!("Extract not supported by ExprPlanner: {extract_args:?}")
             }
 
             SQLExpr::Array(arr) => self.sql_array_literal(arr.elem, schema),
@@ -292,7 +292,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     }
                 }
 
-                not_impl_err!("GetFieldAccess not supported by UserDefinedExtensionPlanners: {field_access_expr:?}")
+                not_impl_err!(
+                    "GetFieldAccess not supported by ExprPlanner: {field_access_expr:?}"
+                )
             }
 
             SQLExpr::CompoundIdentifier(ids) => {
@@ -657,7 +659,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 PlannerResult::Original(args) => create_struct_args = args,
             }
         }
-        not_impl_err!("Struct not supported by UserDefinedExtensionPlanners: {create_struct_args:?}")
+        not_impl_err!("Struct not supported by ExprPlanner: {create_struct_args:?}")
     }
 
     fn sql_position_to_expr(
@@ -680,9 +682,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             }
         }
 
-        not_impl_err!(
-            "Position not supported by UserDefinedExtensionPlanners: {position_args:?}"
-        )
+        not_impl_err!("Position not supported by ExprPlanner: {position_args:?}")
     }
 
     fn try_plan_dictionary_literal(
@@ -914,18 +914,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         schema: &DFSchema,
         planner_context: &mut PlannerContext,
     ) -> Result<Expr> {
-        let fun = self
-            .context_provider
-            .get_function_meta("overlay")
-            .ok_or_else(|| {
-                internal_datafusion_err!("Unable to find expected 'overlay' function")
-            })?;
         let arg = self.sql_expr_to_logical_expr(expr, schema, planner_context)?;
         let what_arg =
             self.sql_expr_to_logical_expr(overlay_what, schema, planner_context)?;
         let from_arg =
             self.sql_expr_to_logical_expr(overlay_from, schema, planner_context)?;
-        let args = match overlay_for {
+        let mut overlay_args = match overlay_for {
             Some(for_expr) => {
                 let for_expr =
                     self.sql_expr_to_logical_expr(*for_expr, schema, planner_context)?;
@@ -933,7 +927,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             }
             None => vec![arg, what_arg, from_arg],
         };
-        Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args)))
+        for planner in self.planners.iter() {
+            match planner.plan_overlay(overlay_args)? {
+                PlannerResult::Planned(expr) => return Ok(expr),
+                PlannerResult::Original(args) => overlay_args = args,
+            }
+        }
+        not_impl_err!("Overlay not supported by ExprPlanner: {overlay_args:?}")
     }
 }
 
From d314ced8090cb599fd7808d7df41699e46ac956e Mon Sep 17 00:00:00 2001
From: Marko Grujic
Date: Thu, 11 Jul 2024 18:22:20 +0200
Subject: [PATCH 12/14] Coerce types for all union children plans when
 eliminating nesting (#11386)

---
 .../optimizer/src/eliminate_nested_union.rs  | 13 +++++++------
 datafusion/sqllogictest/test_files/union.slt | 15 +++++++++++++++
 2 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs
index c8ae937e128a..cc8cf1f56c18 100644
--- a/datafusion/optimizer/src/eliminate_nested_union.rs
+++ b/datafusion/optimizer/src/eliminate_nested_union.rs
@@ -60,7 +60,8 @@ impl OptimizerRule for EliminateNestedUnion {
                 let inputs = inputs
                     .into_iter()
                     .flat_map(extract_plans_from_union)
-                    .collect::<Vec<_>>();
+                    .map(|plan| coerce_plan_expr_for_schema(&plan, &schema))
+                    .collect::<Result<Vec<_>>>()?;
 
                 Ok(Transformed::yes(LogicalPlan::Union(Union {
                     inputs: inputs.into_iter().map(Arc::new).collect_vec(),
@@ -74,7 +75,8 @@ impl OptimizerRule for EliminateNestedUnion {
                     .into_iter()
                     .map(extract_plan_from_distinct)
                     .flat_map(extract_plans_from_union)
-                    .collect::<Vec<_>>();
+                    .map(|plan| coerce_plan_expr_for_schema(&plan, &schema))
+                    .collect::<Result<Vec<_>>>()?;
 
                 Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::All(
                     Arc::new(LogicalPlan::Union(Union {
@@ -95,10 +97,9 @@ impl OptimizerRule for EliminateNestedUnion {
 
 fn extract_plans_from_union(plan: Arc<LogicalPlan>) -> Vec<LogicalPlan> {
     match unwrap_arc(plan) {
-        LogicalPlan::Union(Union { inputs, schema }) => inputs
-            .into_iter()
-            .map(|plan| coerce_plan_expr_for_schema(&plan, &schema).unwrap())
-            .collect::<Vec<_>>(),
+        LogicalPlan::Union(Union { inputs, ..
}) => {
+            inputs.into_iter().map(unwrap_arc).collect::<Vec<_>>()
+        }
         plan => vec![plan],
     }
 }
diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt
index 7b91e97e4a3e..5ede68a42aae 100644
--- a/datafusion/sqllogictest/test_files/union.slt
+++ b/datafusion/sqllogictest/test_files/union.slt
@@ -135,6 +135,21 @@ SELECT SUM(d) FROM (
 ----
 5
 
+# three way union with aggregate and type coercion
+query II rowsort
+SELECT c1, SUM(c2) FROM (
+    SELECT 1 as c1, 1::int as c2
+    UNION
+    SELECT 2 as c1, 2::int as c2
+    UNION
+    SELECT 3 as c1, COALESCE(3::int, 0) as c2
+) as a
+GROUP BY c1
+----
+1 1
+2 2
+3 3
+
 # union_all_with_count
 statement ok
 CREATE table t as SELECT 1 as a
From 4bed04e4e312a0b125306944aee94a93c2ff6c4f Mon Sep 17 00:00:00 2001
From: Georgi Krastev
Date: Thu, 11 Jul 2024 19:26:46 +0300
Subject: [PATCH 13/14] Add customizable equality and hash functions to UDFs
 (#11392)

* Add customizable equality and hash functions to UDFs

* Improve equals and hash_value documentation

* Add tests for parameterized UDFs
---
 .../user_defined/user_defined_aggregates.rs   |  79 ++++++++++-
 .../user_defined_scalar_functions.rs          | 128 +++++++++++++++++-
 datafusion/expr/src/udaf.rs                   |  73 ++++++++--
 datafusion/expr/src/udf.rs                    |  62 +++++++--
 datafusion/expr/src/udwf.rs                   |  69 ++++++++--
 5 files changed, 367 insertions(+), 44 deletions(-)

diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
index d591c662d877..96de865b6554 100644
--- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
@@ -18,14 +18,19 @@
 //! This module contains end to end demonstrations of creating
 //! user defined aggregate functions
 
-use arrow::{array::AsArray, datatypes::Fields};
-use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray};
-use arrow_schema::Schema;
+use std::hash::{DefaultHasher, Hash, Hasher};
 use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
 
+use arrow::{array::AsArray, datatypes::Fields};
+use arrow_array::{
+    types::UInt64Type, Int32Array, PrimitiveArray, StringArray, StructArray,
+};
+use arrow_schema::Schema;
+
+use datafusion::dataframe::DataFrame;
 use datafusion::datasource::MemTable;
 use datafusion::test_util::plan_and_collect;
 use datafusion::{
@@ -45,8 +50,8 @@ use datafusion::{
 };
 use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
 use datafusion_expr::{
-    create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
-    SimpleAggregateUDF,
+    col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
+    LogicalPlanBuilder, SimpleAggregateUDF,
 };
 use datafusion_functions_aggregate::average::AvgAccumulator;
 
@@ -377,6 +382,55 @@ async fn test_groups_accumulator() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn test_parameterized_aggregate_udf() -> Result<()> {
+    let batch = RecordBatch::try_from_iter([(
+        "text",
+        Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
+    )])?;
+
+    let ctx = SessionContext::new();
+    ctx.register_batch("t", batch)?;
+    let t = ctx.table("t").await?;
+    let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable);
+    let udf1 = AggregateUDF::from(TestGroupsAccumulator {
+        signature: signature.clone(),
+        result: 1,
+    });
+    let udf2 = AggregateUDF::from(TestGroupsAccumulator {
+        signature: signature.clone(),
+        result: 2,
+    });
+
+    let plan =
LogicalPlanBuilder::from(t.into_optimized_plan()?)
+        .aggregate(
+            [col("text")],
+            [
+                udf1.call(vec![col("text")]).alias("a"),
+                udf2.call(vec![col("text")]).alias("b"),
+            ],
+        )?
+        .build()?;
+
+    assert_eq!(
+        format!("{plan:?}"),
+        "Aggregate: groupBy=[[t.text]], aggr=[[geo_mean(t.text) AS a, geo_mean(t.text) AS b]]\n  TableScan: t projection=[text]"
+    );
+
+    let actual = DataFrame::new(ctx.state(), plan).collect().await?;
+    let expected = [
+        "+------+---+---+",
+        "| text | a | b |",
+        "+------+---+---+",
+        "| foo  | 1 | 2 |",
+        "+------+---+---+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    ctx.deregister_table("t")?;
+    Ok(())
+}
+
 /// Returns an context with a table "t" and the "first" and "time_sum"
 /// aggregate functions registered.
 ///
@@ -735,6 +789,21 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
     ) -> Result<Box<dyn GroupsAccumulator>> {
         Ok(Box::new(self.clone()))
     }
+
+    fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
+        if let Some(other) = other.as_any().downcast_ref::<TestGroupsAccumulator>() {
+            self.result == other.result && self.signature == other.signature
+        } else {
+            false
+        }
+    }
+
+    fn hash_value(&self) -> u64 {
+        let hasher = &mut DefaultHasher::new();
+        self.signature.hash(hasher);
+        self.result.hash(hasher);
+        hasher.finish()
+    }
 }
 
 impl Accumulator for TestGroupsAccumulator {
diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index 1733068debb9..5847952ae6a6 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -16,11 +16,20 @@
 // under the License.
 
 use std::any::Any;
+use std::hash::{DefaultHasher, Hash, Hasher};
 use std::sync::Arc;
 
 use arrow::compute::kernels::numeric::add;
-use arrow_array::{ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch};
+use arrow_array::builder::BooleanBuilder;
+use arrow_array::cast::AsArray;
+use arrow_array::{
+    Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray,
+};
 use arrow_schema::{DataType, Field, Schema};
+use parking_lot::Mutex;
+use regex::Regex;
+use sqlparser::ast::Ident;
+
 use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
 use datafusion::prelude::*;
 use datafusion::{execution::registry::FunctionRegistry, test_util};
@@ -37,8 +46,6 @@ use datafusion_expr::{
     Volatility,
 };
 use datafusion_functions_array::range::range_udf;
-use parking_lot::Mutex;
-use sqlparser::ast::Ident;
 
 /// test that casting happens on udfs.
 /// c11 is f32, but `custom_sqrt` requires f64.
@@ -1021,6 +1028,121 @@ async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<(
     Ok(())
 }
 
+#[derive(Debug)]
+struct MyRegexUdf {
+    signature: Signature,
+    regex: Regex,
+}
+
+impl MyRegexUdf {
+    fn new(pattern: &str) -> Self {
+        Self {
+            signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable),
+            regex: Regex::new(pattern).expect("regex"),
+        }
+    }
+
+    fn matches(&self, value: Option<&str>) -> Option<bool> {
+        Some(self.regex.is_match(value?))
+    }
+}
+
+impl ScalarUDFImpl for MyRegexUdf {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "regex_udf"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
+        if matches!(args, [DataType::Utf8]) {
+            Ok(DataType::Boolean)
+        } else {
+            plan_err!("regex_udf only accepts a Utf8 argument")
+        }
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        match args {
+            [ColumnarValue::Scalar(ScalarValue::Utf8(value))] => {
+                Ok(ColumnarValue::Scalar(ScalarValue::Boolean(
+                    self.matches(value.as_deref()),
+                )))
+            }
+            [ColumnarValue::Array(values)] => {
+                let mut builder = BooleanBuilder::with_capacity(values.len());
+                for value in values.as_string::<i32>() {
+                    builder.append_option(self.matches(value))
+                }
+                Ok(ColumnarValue::Array(Arc::new(builder.finish())))
+            }
+            _ => exec_err!("regex_udf only accepts Utf8 arguments"),
+        }
+    }
+
+    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
+        if let Some(other) = other.as_any().downcast_ref::<MyRegexUdf>() {
+            self.regex.as_str() == other.regex.as_str()
+        } else {
+            false
+        }
+    }
+
+    fn hash_value(&self) -> u64 {
+        let hasher = &mut DefaultHasher::new();
+        self.regex.as_str().hash(hasher);
+        hasher.finish()
+    }
+}
+
+#[tokio::test]
+async fn test_parameterized_scalar_udf() -> Result<()> {
+    let batch = RecordBatch::try_from_iter([(
+        "text",
+        Arc::new(StringArray::from(vec!["foo", "bar", "foobar", "barfoo"])) as ArrayRef,
+    )])?;
+
+    let ctx = SessionContext::new();
+    ctx.register_batch("t", batch)?;
+    let t = ctx.table("t").await?;
+    let foo_udf = ScalarUDF::from(MyRegexUdf::new("fo{2}"));
+    let bar_udf = ScalarUDF::from(MyRegexUdf::new("[Bb]ar"));
+
+    let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?)
+        .filter(
+            foo_udf
+                .call(vec![col("text")])
+                .and(bar_udf.call(vec![col("text")])),
+        )?
+        .filter(col("text").is_not_null())?
+        .build()?;
+
+    assert_eq!(
+        format!("{plan:?}"),
+        "Filter: t.text IS NOT NULL\n  Filter: regex_udf(t.text) AND regex_udf(t.text)\n    TableScan: t projection=[text]"
+    );
+
+    let actual = DataFrame::new(ctx.state(), plan).collect().await?;
+    let expected = [
+        "+--------+",
+        "| text   |",
+        "+--------+",
+        "| foobar |",
+        "| barfoo |",
+        "+--------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    ctx.deregister_table("t")?;
+    Ok(())
+}
+
 fn create_udf_context() -> SessionContext {
     let ctx = SessionContext::new();
     // register a custom UDF

diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index 7a054abea75b..1657e034fbe2 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -17,6 +17,17 @@
 //! [`AggregateUDF`]: User Defined Aggregate Functions
 
+use std::any::Any;
+use std::fmt::{self, Debug, Formatter};
+use std::hash::{DefaultHasher, Hash, Hasher};
+use std::sync::Arc;
+use std::vec;
+
+use arrow::datatypes::{DataType, Field};
+use sqlparser::ast::NullTreatment;
+
+use datafusion_common::{exec_err, not_impl_err, plan_err, Result};
+
 use crate::expr::AggregateFunction;
 use crate::function::{
     AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
@@ -26,13 +37,6 @@ use crate::utils::format_state_name;
 use crate::utils::AggregateOrderSensitivity;
 use crate::{Accumulator, Expr};
 use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
-use arrow::datatypes::{DataType, Field};
-use datafusion_common::{exec_err, not_impl_err, plan_err, Result};
-use sqlparser::ast::NullTreatment;
-use std::any::Any;
-use std::fmt::{self, Debug, Formatter};
-use std::sync::Arc;
-use std::vec;
 
 /// Logical representation of a user-defined [aggregate function] (UDAF).
 ///
@@ -72,20 +76,19 @@ pub struct AggregateUDF {
 
 impl PartialEq for AggregateUDF {
     fn eq(&self, other: &Self) -> bool {
-        self.name() == other.name() && self.signature() == other.signature()
+        self.inner.equals(other.inner.as_ref())
     }
 }
 
 impl Eq for AggregateUDF {}
 
-impl std::hash::Hash for AggregateUDF {
-    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.name().hash(state);
-        self.signature().hash(state);
+impl Hash for AggregateUDF {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.inner.hash_value().hash(state)
     }
 }
 
-impl std::fmt::Display for AggregateUDF {
+impl fmt::Display for AggregateUDF {
     fn fmt(&self, f: &mut Formatter) -> fmt::Result {
         write!(f, "{}", self.name())
     }
@@ -280,7 +283,7 @@ where
 /// #[derive(Debug, Clone)]
 /// struct GeoMeanUdf {
 ///     signature: Signature
-/// };
+/// }
 ///
 /// impl GeoMeanUdf {
 ///     fn new() -> Self {
@@ -507,6 +510,33 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
     fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
         not_impl_err!("Function {} does not implement coerce_types", self.name())
     }
+
+    /// Return true if this aggregate UDF is equal to the other.
+    ///
+    /// Allows customizing the equality of aggregate UDFs.
+    /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
+    ///
+    /// - reflexive: `a.equals(a)`;
+    /// - symmetric: `a.equals(b)` implies `b.equals(a)`;
+    /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
+    ///
+    /// By default, compares [`Self::name`] and [`Self::signature`].
+    fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
+        self.name() == other.name() && self.signature() == other.signature()
+    }
+
+    /// Returns a hash value for this aggregate UDF.
+    ///
+    /// Allows customizing the hash code of aggregate UDFs. Similarly to [`Hash`] and [`Eq`],
+    /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same.
+    ///
+    /// By default, hashes [`Self::name`] and [`Self::signature`].
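// (Editor's sketch, not part of the patch: with these hooks a parameterized
// UDAF is distinguished by its parameters, not just by its name. Assuming the
// `TestGroupsAccumulator { signature, result }` UDAF from the tests above,
// with `sig` being the `Signature` built in that test:
//
//     let a = AggregateUDF::from(TestGroupsAccumulator { signature: sig.clone(), result: 1 });
//     let b = AggregateUDF::from(TestGroupsAccumulator { signature: sig.clone(), result: 2 });
//     assert_ne!(a, b); // `PartialEq` for `AggregateUDF` now delegates to `equals`
// )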
+    fn hash_value(&self) -> u64 {
+        let hasher = &mut DefaultHasher::new();
+        self.name().hash(hasher);
+        self.signature().hash(hasher);
+        hasher.finish()
+    }
 }
 
 pub enum ReversedUDAF {
@@ -562,6 +592,21 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
+        if let Some(other) = other.as_any().downcast_ref::<Self>() {
+            self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
+        } else {
+            false
+        }
+    }
+
+    fn hash_value(&self) -> u64 {
+        let hasher = &mut DefaultHasher::new();
+        self.inner.hash_value().hash(hasher);
+        self.aliases.hash(hasher);
+        hasher.finish()
+    }
 }
 
 /// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers

diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index 68d3af6ace3c..1fbb3cc584b3 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -19,8 +19,13 @@
 
 use std::any::Any;
 use std::fmt::{self, Debug, Formatter};
+use std::hash::{DefaultHasher, Hash, Hasher};
 use std::sync::Arc;
 
+use arrow::datatypes::DataType;
+
+use datafusion_common::{not_impl_err, ExprSchema, Result};
+
 use crate::expr::create_name;
 use crate::interval_arithmetic::Interval;
 use crate::simplify::{ExprSimplifyResult, SimplifyInfo};
@@ -29,9 +34,6 @@ use crate::{
     ColumnarValue, Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature,
 };
 
-use arrow::datatypes::DataType;
-use datafusion_common::{not_impl_err, ExprSchema, Result};
-
 /// Logical representation of a Scalar User Defined Function.
 ///
 /// A scalar function produces a single row output for each row of input. This
@@ -59,16 +61,15 @@ pub struct ScalarUDF {
 
 impl PartialEq for ScalarUDF {
     fn eq(&self, other: &Self) -> bool {
-        self.name() == other.name() && self.signature() == other.signature()
+        self.inner.equals(other.inner.as_ref())
     }
 }
 
 impl Eq for ScalarUDF {}
 
-impl std::hash::Hash for ScalarUDF {
-    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.name().hash(state);
-        self.signature().hash(state);
+impl Hash for ScalarUDF {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.inner.hash_value().hash(state)
     }
 }
 
@@ -294,7 +295,7 @@ where
 /// #[derive(Debug)]
 /// struct AddOne {
 ///     signature: Signature
-/// };
+/// }
 ///
 /// impl AddOne {
 ///     fn new() -> Self {
@@ -540,6 +541,33 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
     fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
         not_impl_err!("Function {} does not implement coerce_types", self.name())
     }
+
+    /// Return true if this scalar UDF is equal to the other.
+    ///
+    /// Allows customizing the equality of scalar UDFs.
+    /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
+    ///
+    /// - reflexive: `a.equals(a)`;
+    /// - symmetric: `a.equals(b)` implies `b.equals(a)`;
+    /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
+    ///
+    /// By default, compares [`Self::name`] and [`Self::signature`].
+    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
+        self.name() == other.name() && self.signature() == other.signature()
+    }
+
+    /// Returns a hash value for this scalar UDF.
+    ///
+    /// Allows customizing the hash code of scalar UDFs. Similarly to [`Hash`] and [`Eq`],
+    /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same.
+    ///
+    /// By default, hashes [`Self::name`] and [`Self::signature`].
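// (Editor's sketch, not part of the patch: `equals` and `hash_value` must
// agree, so two `MyRegexUdf`s built from the same pattern, as in the tests
// above, compare equal and therefore must hash identically:
//
//     use std::hash::{DefaultHasher, Hash, Hasher};
//     let a = ScalarUDF::from(MyRegexUdf::new("fo{2}"));
//     let b = ScalarUDF::from(MyRegexUdf::new("fo{2}"));
//     let hash = |udf: &ScalarUDF| {
//         let mut h = DefaultHasher::new();
//         udf.hash(&mut h);
//         h.finish()
//     };
//     assert!(a == b && hash(&a) == hash(&b));
// )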
+    fn hash_value(&self) -> u64 {
+        let hasher = &mut DefaultHasher::new();
+        self.name().hash(hasher);
+        self.signature().hash(hasher);
+        hasher.finish()
+    }
 }
 
 /// ScalarUDF that adds an alias to the underlying function. It is better to
@@ -557,7 +585,6 @@ impl AliasedScalarUDFImpl {
     ) -> Self {
         let mut aliases = inner.aliases().to_vec();
         aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
-
         Self { inner, aliases }
     }
 }
@@ -586,6 +613,21 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
+        if let Some(other) = other.as_any().downcast_ref::<Self>() {
+            self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
+        } else {
+            false
+        }
+    }
+
+    fn hash_value(&self) -> u64 {
+        let hasher = &mut DefaultHasher::new();
+        self.inner.hash_value().hash(hasher);
+        self.aliases.hash(hasher);
+        hasher.finish()
+    }
 }
 
 /// Implementation of [`ScalarUDFImpl`] that wraps the function style pointers

diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs
index 70b44e5e307a..1a6b21e3dd29 100644
--- a/datafusion/expr/src/udwf.rs
+++ b/datafusion/expr/src/udwf.rs
@@ -17,18 +17,22 @@
 
 //! [`WindowUDF`]: User Defined Window Functions
 
-use crate::{
-    function::WindowFunctionSimplification, Expr, PartitionEvaluator,
-    PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame,
-};
-use arrow::datatypes::DataType;
-use datafusion_common::Result;
+use std::hash::{DefaultHasher, Hash, Hasher};
 use std::{
     any::Any,
     fmt::{self, Debug, Display, Formatter},
     sync::Arc,
 };
 
+use arrow::datatypes::DataType;
+
+use datafusion_common::Result;
+
+use crate::{
+    function::WindowFunctionSimplification, Expr, PartitionEvaluator,
+    PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame,
+};
+
 /// Logical representation of a user-defined window function (UDWF)
 /// A UDWF is different from a UDF in that it is stateful across batches.
 ///
@@ -62,16 +66,15 @@ impl Display for WindowUDF {
 
 impl PartialEq for WindowUDF {
     fn eq(&self, other: &Self) -> bool {
-        self.name() == other.name() && self.signature() == other.signature()
+        self.inner.equals(other.inner.as_ref())
     }
 }
 
 impl Eq for WindowUDF {}
 
-impl std::hash::Hash for WindowUDF {
-    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.name().hash(state);
-        self.signature().hash(state);
+impl Hash for WindowUDF {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.inner.hash_value().hash(state)
     }
 }
 
@@ -212,7 +215,7 @@ where
 /// #[derive(Debug, Clone)]
 /// struct SmoothIt {
 ///     signature: Signature
-/// };
+/// }
 ///
 /// impl SmoothIt {
 ///     fn new() -> Self {
@@ -296,6 +299,33 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
     fn simplify(&self) -> Option<WindowFunctionSimplification> {
         None
     }
+
+    /// Return true if this window UDF is equal to the other.
+    ///
+    /// Allows customizing the equality of window UDFs.
+    /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
+    ///
+    /// - reflexive: `a.equals(a)`;
+    /// - symmetric: `a.equals(b)` implies `b.equals(a)`;
+    /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
+    ///
+    /// By default, compares [`Self::name`] and [`Self::signature`].
+    fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
+        self.name() == other.name() && self.signature() == other.signature()
+    }
+
+    /// Returns a hash value for this window UDF.
+    ///
+    /// Allows customizing the hash code of window UDFs.
Similarly to [`Hash`] and [`Eq`], + /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same. + /// + /// By default, hashes [`Self::name`] and [`Self::signature`]. + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.name().hash(hasher); + self.signature().hash(hasher); + hasher.finish() + } } /// WindowUDF that adds an alias to the underlying function. It is better to @@ -342,6 +372,21 @@ impl WindowUDFImpl for AliasedWindowUDFImpl { fn aliases(&self) -> &[String] { &self.aliases } + + fn equals(&self, other: &dyn WindowUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases + } else { + false + } + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.inner.hash_value().hash(hasher); + self.aliases.hash(hasher); + hasher.finish() + } } /// Implementation of [`WindowUDFImpl`] that wraps the function style pointers From 5ba634aa4f6d3d4ed5eefbc15dba5448f4f30923 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Fri, 12 Jul 2024 14:43:49 +0800 Subject: [PATCH 14/14] Implement ScalarFunction `MAKE_MAP` and `MAP` (#11361) * tmp * opt * modify test * add another version * implement make_map function * implement make_map function * implement map function * format and modify the doc * add benchmark for map function * add empty end-line * fix cargo check * update lock * upate lock * fix clippy * fmt and clippy * support FixedSizeList and LargeList * check type and handle null array in coerce_types * make array value throw todo error * fix clippy * simpify the error tests --- datafusion-cli/Cargo.lock | 1 + datafusion/functions/Cargo.toml | 7 +- datafusion/functions/benches/map.rs | 101 +++++++ datafusion/functions/src/core/map.rs | 312 +++++++++++++++++++++ datafusion/functions/src/core/mod.rs | 13 + datafusion/sqllogictest/test_files/map.slt | 112 ++++++++ 6 files changed, 545 insertions(+), 1 deletion(-) create mode 100644 datafusion/functions/benches/map.rs create mode 100644 datafusion/functions/src/core/map.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 8af42cb43932..7da9cc427c37 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1278,6 +1278,7 @@ name = "datafusion-functions" version = "40.0.0" dependencies = [ "arrow", + "arrow-buffer", "base64 0.22.1", "blake2", "blake3", diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 884a66724c91..b143080b1962 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -66,6 +66,7 @@ path = "src/lib.rs" [dependencies] arrow = { workspace = true } +arrow-buffer = { workspace = true } base64 = { version = "0.22", optional = true } blake2 = { version = "^0.10.2", optional = true } blake3 = { version = "1.0", optional = true } @@ -86,7 +87,6 @@ uuid = { version = "1.7", features = ["v4"], optional = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } -arrow-buffer = { workspace = true } criterion = "0.5" rand = { workspace = true } rstest = { workspace = true } @@ -141,3 +141,8 @@ required-features = ["string_expressions"] harness = false name = "upper" required-features = ["string_expressions"] + +[[bench]] +harness = false +name = "map" +required-features = ["core_expressions"] diff --git a/datafusion/functions/benches/map.rs b/datafusion/functions/benches/map.rs new file mode 100644 index 000000000000..cd863d0e3311 --- /dev/null +++ 
b/datafusion/functions/benches/map.rs @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::array::{Int32Array, ListArray, StringArray}; +use arrow::datatypes::{DataType, Field}; +use arrow_buffer::{OffsetBuffer, ScalarBuffer}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::ScalarValue; +use datafusion_expr::ColumnarValue; +use datafusion_functions::core::{make_map, map}; +use rand::prelude::ThreadRng; +use rand::Rng; +use std::sync::Arc; + +fn keys(rng: &mut ThreadRng) -> Vec { + let mut keys = vec![]; + for _ in 0..1000 { + keys.push(rng.gen_range(0..9999).to_string()); + } + keys +} + +fn values(rng: &mut ThreadRng) -> Vec { + let mut values = vec![]; + for _ in 0..1000 { + values.push(rng.gen_range(0..9999)); + } + values +} + +fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("make_map_1000", |b| { + let mut rng = rand::thread_rng(); + let keys = keys(&mut rng); + let values = values(&mut rng); + let mut buffer = Vec::new(); + for i in 0..1000 { + buffer.push(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + keys[i].clone(), + )))); + buffer.push(ColumnarValue::Scalar(ScalarValue::Int32(Some(values[i])))); + } + + b.iter(|| { + black_box( + make_map() + .invoke(&buffer) + .expect("map should work on valid values"), + ); + }); + }); + + c.bench_function("map_1000", |b| { + let mut rng = rand::thread_rng(); + let field = Arc::new(Field::new("item", DataType::Utf8, true)); + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 1000])); + let key_list = ListArray::new( + field, + offsets, + Arc::new(StringArray::from(keys(&mut rng))), + None, + ); + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 1000])); + let value_list = ListArray::new( + field, + offsets, + Arc::new(Int32Array::from(values(&mut rng))), + None, + ); + let keys = ColumnarValue::Scalar(ScalarValue::List(Arc::new(key_list))); + let values = ColumnarValue::Scalar(ScalarValue::List(Arc::new(value_list))); + + b.iter(|| { + black_box( + map() + .invoke(&[keys.clone(), values.clone()]) + .expect("map should work on valid values"), + ); + }); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/core/map.rs b/datafusion/functions/src/core/map.rs new file mode 100644 index 000000000000..8a8a19d7af52 --- /dev/null +++ b/datafusion/functions/src/core/map.rs @@ -0,0 +1,312 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::collections::VecDeque;
+use std::sync::Arc;
+
+use arrow::array::{Array, ArrayData, ArrayRef, MapArray, StructArray};
+use arrow::compute::concat;
+use arrow::datatypes::{DataType, Field, SchemaBuilder};
+use arrow_buffer::{Buffer, ToByteSlice};
+
+use datafusion_common::{exec_err, internal_err, ScalarValue};
+use datafusion_common::{not_impl_err, Result};
+use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+
+fn make_map(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    let (key, value): (Vec<_>, Vec<_>) = args
+        .chunks_exact(2)
+        .map(|chunk| {
+            if let ColumnarValue::Array(_) = chunk[0] {
+                return not_impl_err!("make_map does not support array keys");
+            }
+            if let ColumnarValue::Array(_) = chunk[1] {
+                return not_impl_err!("make_map does not support array values");
+            }
+            Ok((chunk[0].clone(), chunk[1].clone()))
+        })
+        .collect::<Result<Vec<_>>>()?
+        .into_iter()
+        .unzip();
+
+    let keys = ColumnarValue::values_to_arrays(&key)?;
+    let values = ColumnarValue::values_to_arrays(&value)?;
+
+    let keys: Vec<_> = keys.iter().map(|k| k.as_ref()).collect();
+    let values: Vec<_> = values.iter().map(|v| v.as_ref()).collect();
+
+    let key = match concat(&keys) {
+        Ok(key) => key,
+        Err(e) => return internal_err!("Error concatenating keys: {}", e),
+    };
+    let value = match concat(&values) {
+        Ok(value) => value,
+        Err(e) => return internal_err!("Error concatenating values: {}", e),
+    };
+    make_map_batch_internal(key, value)
+}
+
+fn make_map_batch(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    if args.len() != 2 {
+        return exec_err!(
+            "make_map requires exactly 2 arguments, got {} instead",
+            args.len()
+        );
+    }
+    let key = get_first_array_ref(&args[0])?;
+    let value = get_first_array_ref(&args[1])?;
+    make_map_batch_internal(key, value)
+}
+
+fn get_first_array_ref(columnar_value: &ColumnarValue) -> Result<ArrayRef> {
+    match columnar_value {
+        ColumnarValue::Scalar(value) => match value {
+            ScalarValue::List(array) => Ok(array.value(0).clone()),
+            ScalarValue::LargeList(array) => Ok(array.value(0).clone()),
+            ScalarValue::FixedSizeList(array) => Ok(array.value(0).clone()),
+            _ => exec_err!("Expected array, got {:?}", value),
+        },
+        ColumnarValue::Array(array) => exec_err!("Expected scalar, got {:?}", array),
+    }
+}
+
+fn make_map_batch_internal(keys: ArrayRef, values: ArrayRef) -> Result<ColumnarValue> {
+    if keys.null_count() > 0 {
+        return exec_err!("map key cannot be null");
+    }
+
+    if keys.len() != values.len() {
+        return exec_err!("map requires key and value lists to have the same length");
+    }
+
+    let key_field = Arc::new(Field::new("key", keys.data_type().clone(), false));
+    let value_field = Arc::new(Field::new("value", values.data_type().clone(), true));
+    let mut entry_struct_buffer: VecDeque<(Arc<Field>, ArrayRef)> = VecDeque::new();
+    let mut entry_offsets_buffer = VecDeque::new();
+    entry_offsets_buffer.push_back(0);
+
+    entry_struct_buffer.push_back((Arc::clone(&key_field), Arc::clone(&keys)));
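// (Editor's note, not part of the patch: a single map value with N entries is
// encoded as a one-row MapArray, so the offsets buffer assembled here is just
// [0, N], and the "entries" StructArray carries the parallel key and value
// child arrays pushed above and below.)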
+    entry_struct_buffer.push_back((Arc::clone(&value_field), Arc::clone(&values)));
+    entry_offsets_buffer.push_back(keys.len() as u32);
+
+    let entry_struct: Vec<(Arc<Field>, ArrayRef)> = entry_struct_buffer.into();
+    let entry_struct = StructArray::from(entry_struct);
+
+    let map_data_type = DataType::Map(
+        Arc::new(Field::new(
+            "entries",
+            entry_struct.data_type().clone(),
+            false,
+        )),
+        false,
+    );
+
+    let entry_offsets: Vec<u32> = entry_offsets_buffer.into();
+    let entry_offsets_buffer = Buffer::from(entry_offsets.to_byte_slice());
+
+    let map_data = ArrayData::builder(map_data_type)
+        .len(entry_offsets.len() - 1)
+        .add_buffer(entry_offsets_buffer)
+        .add_child_data(entry_struct.to_data())
+        .build()?;
+
+    Ok(ColumnarValue::Array(Arc::new(MapArray::from(map_data))))
+}
+
+#[derive(Debug)]
+pub struct MakeMap {
+    signature: Signature,
+}
+
+impl Default for MakeMap {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl MakeMap {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for MakeMap {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "make_map"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        if arg_types.is_empty() {
+            return exec_err!(
+                "make_map requires at least one pair of arguments, got 0 instead"
+            );
+        }
+        if arg_types.len() % 2 != 0 {
+            return exec_err!(
+                "make_map requires an even number of arguments, got {} instead",
+                arg_types.len()
+            );
+        }
+
+        let key_type = &arg_types[0];
+        let mut value_type = &arg_types[1];
+
+        for (i, chunk) in arg_types.chunks_exact(2).enumerate() {
+            if chunk[0].is_null() {
+                return exec_err!("make_map key cannot be null at position {}", i);
+            }
+            if &chunk[0] != key_type {
+                return exec_err!(
+                    "make_map requires all keys to have the same type {}, got {} instead at position {}",
+                    key_type,
+                    chunk[0],
+                    i
+                );
+            }
+
+            if !chunk[1].is_null() {
+                if value_type.is_null() {
+                    value_type = &chunk[1];
+                } else if &chunk[1] != value_type {
+                    return exec_err!(
+                        "map requires all values to have the same type {}, got {} instead at position {}",
+                        value_type,
+                        &chunk[1],
+                        i
+                    );
+                }
+            }
+        }
+
+        let mut result = Vec::new();
+        for _ in 0..arg_types.len() / 2 {
+            result.push(key_type.clone());
+            result.push(value_type.clone());
+        }
+
+        Ok(result)
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        let key_type = &arg_types[0];
+        let mut value_type = &arg_types[1];
+
+        for chunk in arg_types.chunks_exact(2) {
+            if !chunk[1].is_null() && value_type.is_null() {
+                value_type = &chunk[1];
+            }
+        }
+
+        let mut builder = SchemaBuilder::new();
+        builder.push(Field::new("key", key_type.clone(), false));
+        builder.push(Field::new("value", value_type.clone(), true));
+        let fields = builder.finish().fields;
+        Ok(DataType::Map(
+            Arc::new(Field::new("entries", DataType::Struct(fields), false)),
+            false,
+        ))
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        make_map(args)
+    }
+}
+
+#[derive(Debug)]
+pub struct MapFunc {
+    signature: Signature,
+}
+
+impl Default for MapFunc {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl MapFunc {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::variadic_any(Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for MapFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "map"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types:
&[DataType]) -> Result { + if arg_types.len() % 2 != 0 { + return exec_err!( + "map requires an even number of arguments, got {} instead", + arg_types.len() + ); + } + let mut builder = SchemaBuilder::new(); + builder.push(Field::new( + "key", + get_element_type(&arg_types[0])?.clone(), + false, + )); + builder.push(Field::new( + "value", + get_element_type(&arg_types[1])?.clone(), + true, + )); + let fields = builder.finish().fields; + Ok(DataType::Map( + Arc::new(Field::new("entries", DataType::Struct(fields), false)), + false, + )) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_map_batch(args) + } +} + +fn get_element_type(data_type: &DataType) -> Result<&DataType> { + match data_type { + DataType::List(element) => Ok(element.data_type()), + DataType::LargeList(element) => Ok(element.data_type()), + DataType::FixedSizeList(element, _) => Ok(element.data_type()), + _ => exec_err!( + "Expected list, large_list or fixed_size_list, got {:?}", + data_type + ), + } +} diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 062a4a104d54..31bce04beec1 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -25,6 +25,7 @@ pub mod arrowtypeof; pub mod coalesce; pub mod expr_ext; pub mod getfield; +pub mod map; pub mod named_struct; pub mod nullif; pub mod nvl; @@ -42,6 +43,8 @@ make_udf_function!(r#struct::StructFunc, STRUCT, r#struct); make_udf_function!(named_struct::NamedStructFunc, NAMED_STRUCT, named_struct); make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field); make_udf_function!(coalesce::CoalesceFunc, COALESCE, coalesce); +make_udf_function!(map::MakeMap, MAKE_MAP, make_map); +make_udf_function!(map::MapFunc, MAP, map); pub mod expr_fn { use datafusion_expr::{Expr, Literal}; @@ -78,6 +81,14 @@ pub mod expr_fn { coalesce, "Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL", args, + ),( + make_map, + "Returns a map created from the given keys and values pairs. This function isn't efficient for large maps. Use the `map` function instead.", + args, + ),( + map, + "Returns a map created from a key list and a value list", + args, )); #[doc = "Returns the value of the field with the given name from the struct"] @@ -96,5 +107,7 @@ pub fn functions() -> Vec> { named_struct(), get_field(), coalesce(), + make_map(), + map(), ] } diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 417947dc6c89..abf5b2ebbf98 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -100,3 +100,115 @@ physical_plan statement ok drop table table_with_map; + +query ? +SELECT MAKE_MAP('POST', 41, 'HEAD', 33, 'PATCH', 30, 'OPTION', 29, 'GET', 27, 'PUT', 25, 'DELETE', 24) AS method_count; +---- +{POST: 41, HEAD: 33, PATCH: 30, OPTION: 29, GET: 27, PUT: 25, DELETE: 24} + +query I +SELECT MAKE_MAP('POST', 41, 'HEAD', 33)['POST']; +---- +41 + +query ? +SELECT MAKE_MAP('POST', 41, 'HEAD', 33, 'PATCH', null); +---- +{POST: 41, HEAD: 33, PATCH: } + +query ? +SELECT MAKE_MAP('POST', null, 'HEAD', 33, 'PATCH', null); +---- +{POST: , HEAD: 33, PATCH: } + +query ? +SELECT MAKE_MAP(1, null, 2, 33, 3, null); +---- +{1: , 2: 33, 3: } + +query ? 
+SELECT MAKE_MAP([1,2], ['a', 'b'], [3,4], ['b']); +---- +{[1, 2]: [a, b], [3, 4]: [b]} + +query error +SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30); + +query error +SELECT MAKE_MAP('POST', 41, 'HEAD', 33, null, 30); + +query error +SELECT MAKE_MAP('POST', 41, 123, 33,'PATCH', 30); + +query error +SELECT MAKE_MAP() + +query error +SELECT MAKE_MAP('POST', 41, 'HEAD'); + +query ? +SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, 30]); +---- +{POST: 41, HEAD: 33, PATCH: 30} + +query ? +SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, null]); +---- +{POST: 41, HEAD: 33, PATCH: } + +query ? +SELECT MAP([[1,2], [3,4]], ['a', 'b']); +---- +{[1, 2]: a, [3, 4]: b} + +query error +SELECT MAP() + +query error DataFusion error: Execution error: map requires an even number of arguments, got 1 instead +SELECT MAP(['POST', 'HEAD']) + +query error DataFusion error: Execution error: Expected list, large_list or fixed_size_list, got Null +SELECT MAP(null, [41, 33, 30]); + +query error DataFusion error: Execution error: map requires key and value lists to have the same length +SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33]); + +query error DataFusion error: Execution error: map key cannot be null +SELECT MAP(['POST', 'HEAD', null], [41, 33, 30]); + +query ? +SELECT MAP(make_array('POST', 'HEAD', 'PATCH'), make_array(41, 33, 30)); +---- +{POST: 41, HEAD: 33, PATCH: 30} + +query ? +SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'), 'FixedSizeList(3, Utf8)'), arrow_cast(make_array(41, 33, 30), 'FixedSizeList(3, Int64)')); +---- +{POST: 41, HEAD: 33, PATCH: 30} + +query ? +SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'), 'LargeList(Utf8)'), arrow_cast(make_array(41, 33, 30), 'LargeList(Int64)')); +---- +{POST: 41, HEAD: 33, PATCH: 30} + +statement ok +create table t as values +('a', 1, 'k1', 10, ['k1', 'k2'], [1, 2]), +('b', 2, 'k3', 30, ['k3'], [3]), +('d', 4, 'k5', 50, ['k5'], [5]); + +query error +SELECT make_map(column1, column2, column3, column4) FROM t; +# TODO: support array value +# ---- +# {a: 1, k1: 10} +# {b: 2, k3: 30} +# {d: 4, k5: 50} + +query error +SELECT map(column5, column6) FROM t; +# TODO: support array value +# ---- +# {k1:1, k2:2} +# {k3: 3} +# {k5: 5}
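
Editor's note: a minimal end-to-end sketch of calling the `make_map` expression
added above from Rust. The `datafusion_functions::core::expr_fn::make_map`
path and its `Vec<Expr>` argument are assumptions inferred from the
`export_functions!` block in this patch, not confirmed API:

    use datafusion::error::Result;
    use datafusion::prelude::*;
    use datafusion_functions::core::expr_fn::make_map;

    #[tokio::main]
    async fn main() -> Result<()> {
        let ctx = SessionContext::new();
        // One row with no columns is enough to evaluate constant expressions.
        let df = ctx.read_empty()?;
        // Alternating key/value arguments, mirroring the
        // SELECT MAKE_MAP('POST', 41, 'HEAD', 33) test above.
        let df = df.select(vec![
            make_map(vec![lit("POST"), lit(41), lit("HEAD"), lit(33)]).alias("m"),
        ])?;
        df.show().await?; // expected: {POST: 41, HEAD: 33}
        Ok(())
    }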