From d3a38d195b1ca7e5cbcb5ad86cfadb0ddb9062f2 Mon Sep 17 00:00:00 2001
From: liukun4515
Date: Wed, 17 Aug 2022 12:29:55 +0800
Subject: [PATCH 1/3] add rule to pre-cast literals in binary comparison

---
 datafusion/core/src/execution/context.rs      |   2 +
 .../core/tests/provider_filter_pushdown.rs    |  14 +
 datafusion/core/tests/sql/explain_analyze.rs  |  44 +--
 datafusion/core/tests/sql/subqueries.rs       |   4 +-
 datafusion/optimizer/src/lib.rs               |   1 +
 .../src/pre_cast_lit_in_binary_comparison.rs  | 295 ++++++++++++++++++
 6 files changed, 336 insertions(+), 24 deletions(-)
 create mode 100644 datafusion/optimizer/src/pre_cast_lit_in_binary_comparison.rs

diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index 96705bb0cab5..601a812bfc8d 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -106,6 +106,7 @@ use datafusion_optimizer::decorrelate_scalar_subquery::DecorrelateScalarSubquery
 use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists;
 use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn;
 use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
+use datafusion_optimizer::pre_cast_lit_in_binary_comparison::PreCastLitInBinaryComparisonExpressions;
 use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
 use datafusion_sql::{
     parser::DFParser,
@@ -1360,6 +1361,7 @@ impl SessionState {
             // Simplify expressions first to maximize the chance
             // of applying other optimizations
             Arc::new(SimplifyExpressions::new()),
+            Arc::new(PreCastLitInBinaryComparisonExpressions::new()),
             Arc::new(DecorrelateWhereExists::new()),
             Arc::new(DecorrelateWhereIn::new()),
             Arc::new(DecorrelateScalarSubquery::new()),

diff --git a/datafusion/core/tests/provider_filter_pushdown.rs b/datafusion/core/tests/provider_filter_pushdown.rs
index 3ebfec996e64..283a39a85d29 100644
--- a/datafusion/core/tests/provider_filter_pushdown.rs
+++ b/datafusion/core/tests/provider_filter_pushdown.rs
@@ -31,6 +31,7 @@ use datafusion::physical_plan::{
 };
 use datafusion::prelude::*;
 use datafusion::scalar::ScalarValue;
+use std::ops::Deref;
 use std::sync::Arc;
 
 fn create_batch(value: i32, num_rows: usize) -> Result<RecordBatch> {
@@ -146,7 +147,20 @@ impl TableProvider for CustomProvider {
         match &filters[0] {
             Expr::BinaryExpr { right, .. } => {
                let int_value = match &**right {
+                    Expr::Literal(ScalarValue::Int8(i)) => i.unwrap() as i64,
+                    Expr::Literal(ScalarValue::Int16(i)) => i.unwrap() as i64,
+                    Expr::Literal(ScalarValue::Int32(i)) => i.unwrap() as i64,
                     Expr::Literal(ScalarValue::Int64(i)) => i.unwrap(),
+                    Expr::Cast { expr, data_type: _ } => match expr.deref() {
+                        Expr::Literal(lit_value) => match lit_value {
+                            ScalarValue::Int8(v) => v.unwrap() as i64,
+                            ScalarValue::Int16(v) => v.unwrap() as i64,
+                            ScalarValue::Int32(v) => v.unwrap() as i64,
+                            ScalarValue::Int64(v) => v.unwrap(),
+                            _ => unimplemented!(),
+                        },
+                        _ => unimplemented!(),
+                    },
                     _ => unimplemented!(),
                 };

diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs
index 02db3e873330..2b801ed01cb9 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -271,8 +271,8 @@ async fn csv_explain_plans() {
     let expected = vec![
         "Explain [plan_type:Utf8, plan:Utf8]",
         "  Projection: #aggregate_test_100.c1 [c1:Utf8]",
-        "    Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32]",
-        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)] [c1:Utf8, c2:Int32]",
+        "    Filter: #aggregate_test_100.c2 > Int32(10) [c1:Utf8, c2:Int32]",
+        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)] [c1:Utf8, c2:Int32]",
     ];
     let formatted = plan.display_indent_schema().to_string();
     let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -286,8 +286,8 @@ async fn csv_explain_plans() {
     let expected = vec![
         "Explain",
         "  Projection: #aggregate_test_100.c1",
-        "    Filter: #aggregate_test_100.c2 > Int64(10)",
-        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]",
+        "    Filter: #aggregate_test_100.c2 > Int32(10)",
+        "      TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]",
     ];
     let formatted = plan.display_indent().to_string();
     let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -307,9 +307,9 @@ async fn csv_explain_plans() {
         "    2[shape=box label=\"Explain\"]",
         "    3[shape=box label=\"Projection: #aggregate_test_100.c1\"]",
         "    2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]",
+        "    4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\"]",
         "    3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\"]",
+        "    5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\"]",
         "    4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]",
         "  }",
         "  subgraph cluster_6",
         "  {",
         "    graph[label=\"Detailed LogicalPlan\"]",
         "    7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
         "    8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
         "    7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
+        "    9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
         "    8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
-        "    10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
c2:Int32]\"]", + " 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]", " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", " }", "}", @@ -349,7 +349,7 @@ async fn csv_explain_plans() { // Since the plan contains path that are environmentally dependant (e.g. full path of the test file), only verify important content assert_contains!(&actual, "logical_plan"); assert_contains!(&actual, "Projection: #aggregate_test_100.c1"); - assert_contains!(actual, "Filter: #aggregate_test_100.c2 > Int64(10)"); + assert_contains!(actual, "Filter: #aggregate_test_100.c2 > Int32(10)"); } #[tokio::test] @@ -469,8 +469,8 @@ async fn csv_explain_verbose_plans() { let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Projection: #aggregate_test_100.c1 [c1:Utf8]", - " Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32]", - " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)] [c1:Utf8, c2:Int32]", + " Filter: #aggregate_test_100.c2 > Int32(10) [c1:Utf8, c2:Int32]", + " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)] [c1:Utf8, c2:Int32]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -484,8 +484,8 @@ async fn csv_explain_verbose_plans() { let expected = vec![ "Explain", " Projection: #aggregate_test_100.c1", - " Filter: #aggregate_test_100.c2 > Int64(10)", - " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]", + " Filter: #aggregate_test_100.c2 > Int32(10)", + " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]", ]; let formatted = plan.display_indent().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -505,9 +505,9 @@ async fn csv_explain_verbose_plans() { " 2[shape=box label=\"Explain\"]", " 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]", " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", - " 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\"]", + " 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\"]", " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", - " 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\"]", + " 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\"]", " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", " }", " subgraph cluster_6", @@ -516,9 +516,9 @@ async fn csv_explain_verbose_plans() { " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", " 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", - " 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]", + " 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\\nSchema: [c1:Utf8, c2:Int32]\"]", " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]", + " 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\\nSchema: 
         "    9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
         "  }",
         "}",
@@ -549,7 +549,7 @@ async fn csv_explain_verbose_plans() {
     // important content
     assert_contains!(&actual, "logical_plan after projection_push_down");
     assert_contains!(&actual, "physical_plan");
-    assert_contains!(&actual, "FilterExec: CAST(c2@1 AS Int64) > 10");
+    assert_contains!(&actual, "FilterExec: c2@1 > 10");
     assert_contains!(actual, "ProjectionExec: expr=[c1@0 as c1]");
 }
 
@@ -745,7 +745,7 @@ async fn csv_explain() {
     // then execute the physical plan and return the final explain results
     let ctx = SessionContext::new();
     register_aggregate_csv_by_sql(&ctx).await;
-    let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > 10";
+    let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > cast(10 as int)";
     let actual = execute(&ctx, sql).await;
     let actual = normalize_vec_for_explain(actual);
 
@@ -755,13 +755,13 @@ async fn csv_explain() {
         vec![
             "logical_plan",
             "Projection: #aggregate_test_100.c1\
-            \n  Filter: #aggregate_test_100.c2 > Int64(10)\
-            \n    TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int64(10)]"
+            \n  Filter: #aggregate_test_100.c2 > Int32(10)\
+            \n    TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]"
         ],
         vec!["physical_plan",
             "ProjectionExec: expr=[c1@0 as c1]\
             \n  CoalesceBatchesExec: target_batch_size=4096\
-            \n    FilterExec: CAST(c2@1 AS Int64) > 10\
+            \n    FilterExec: c2@1 > 10\
             \n      RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES)\
             \n        CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None, projection=[c1, c2]\
             \n"
         ],

diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs
index 4eaf921f6937..d85a2693253a 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -147,8 +147,8 @@ order by s_acctbal desc, n_name, s_name, p_partkey;"#;
         Inner Join: #supplier.s_nationkey = #nation.n_nationkey
           Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
             Inner Join: #part.p_partkey = #partsupp.ps_partkey
-              Filter: #part.p_size = Int64(15) AND #part.p_type LIKE Utf8("%BRASS")
-                TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[#part.p_size = Int64(15), #part.p_type LIKE Utf8("%BRASS")]
+              Filter: #part.p_size = Int32(15) AND #part.p_type LIKE Utf8("%BRASS")
+                TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[#part.p_size = Int32(15), #part.p_type LIKE Utf8("%BRASS")]
               TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
               TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
               TableScan: nation projection=[n_nationkey, n_name, n_regionkey]

diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index 6da67b6fc132..cb79f2164ed1 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -33,6 +33,7 @@ pub mod single_distinct_to_groupby;
 pub mod subquery_filter_to_join;
 pub mod utils;
 
+pub mod pre_cast_lit_in_binary_comparison;
 pub mod rewrite_disjunctive_predicate;
 #[cfg(test)]
 pub mod test;

diff --git a/datafusion/optimizer/src/pre_cast_lit_in_binary_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_binary_comparison.rs
new file mode 100644
index 000000000000..490c5a25c0d6
--- /dev/null
+++ b/datafusion/optimizer/src/pre_cast_lit_in_binary_comparison.rs
@@ -0,0 +1,295 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! The pre-cast-literal rule applies only to binary comparison exprs.
+//! It avoids wrapping the non-literal side in `Expr::Cast` by casting the literal expr instead.
+use crate::{OptimizerConfig, OptimizerRule};
+use arrow::datatypes::DataType;
+use datafusion_common::{DFSchemaRef, Result, ScalarValue};
+use datafusion_expr::utils::from_plan;
+use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan, Operator};
+
+/// This rule applies only to numeric binary comparisons with a literal expr, i.e. the patterns
+/// `left_expr comparison_op literal_expr` or `literal_expr comparison_op right_expr`.
+/// For now the data types of both sides must be signed integer types; more data types will be supported later.
+///
+/// If a binary comparison expr matches the pattern, the optimizer checks whether the value of the
+/// `literal` lies within the (min, max) range of the data type of `left_expr` or `right_expr`.
+///
+/// If it does, the literal expr is cast to the data type of the expr on the other side, so the
+/// binary comparison becomes `left_expr comparison_op cast(literal_expr, left_data_type)` or
+/// `cast(literal_expr, right_data_type) comparison_op right_expr`. For better optimization,
+/// the expr `cast(literal_expr, target_type)` is precomputed and folded into a new literal expr
+/// whose data type is `target_type`.
+/// If it does not, nothing is changed.
+///
+/// This is inspired by Spark's optimizer rule `UnwrapCastInBinaryComparison`.
+/// # Example
+///
+/// `Filter: c1 > INT64(10)` will be optimized to `Filter: c1 > CAST(INT64(10) AS INT32)`,
+/// and then folded to `Filter: c1 > INT32(10)`, if the DataType of c1 is INT32.
+///
+#[derive(Default)]
+pub struct PreCastLitInBinaryComparisonExpressions {}
+
+impl PreCastLitInBinaryComparisonExpressions {
+    pub fn new() -> Self {
+        Self::default()
+    }
+}
+
+impl OptimizerRule for PreCastLitInBinaryComparisonExpressions {
+    fn optimize(
+        &self,
+        plan: &LogicalPlan,
+        _optimizer_config: &mut OptimizerConfig,
+    ) -> Result<LogicalPlan> {
+        optimize(plan)
+    }
+
+    fn name(&self) -> &str {
+        "pre_cast_lit_in_binary_comparison"
+    }
+}
+
+fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
+    let new_inputs = plan
+        .inputs()
+        .iter()
+        .map(|input| optimize(input))
+        .collect::<Result<Vec<_>>>()?;
+
+    let schema = plan.schema();
+    let new_exprs = plan
+        .expressions()
+        .into_iter()
+        .map(|expr| visit_expr(expr, schema))
+        .collect::<Vec<Expr>>();
+
+    from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
+}
+
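// ---------------------------------------------------------------------------
// [Editorial note, not part of the patch] A minimal sketch of driving this rule
// on its own. It assumes the DataFusion-11-era APIs used elsewhere in this
// patch (`OptimizerRule::optimize`, `OptimizerConfig`) plus `LogicalPlanBuilder`
// from datafusion_expr; the table "t" and column "a" are hypothetical, and
// names/paths may differ in other versions.
//
//     use arrow::datatypes::{DataType, Field, Schema};
//     use datafusion_common::Result;
//     use datafusion_expr::{col, lit, logical_plan::LogicalPlanBuilder};
//
//     fn example() -> Result<()> {
//         // a table "t" with a single Int32 column "a"
//         let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
//         let plan = LogicalPlanBuilder::scan_empty(Some("t"), &schema, None)?
//             .filter(col("a").gt(lit(10i64)))? // predicate: #t.a > Int64(10)
//             .build()?;
//         let optimized = PreCastLitInBinaryComparisonExpressions::new()
//             .optimize(&plan, &mut OptimizerConfig::new())?;
//         // The literal is now Int32(10) and no cast wraps the column:
//         // Filter: #t.a > Int32(10)
//         Ok(())
//     }
// ---------------------------------------------------------------------------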
+// Visit every kind of expr; if the current expr has children, visit the children first.
+fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Expr {
+    // traverse the expr by dfs
+    match &expr {
+        Expr::BinaryExpr { left, op, right } => {
+            // dfs visit the left and right expr
+            let left = visit_expr(*left.clone(), schema);
+            let right = visit_expr(*right.clone(), schema);
+            let left_type = left.get_type(schema);
+            let right_type = right.get_type(schema);
+            // can't get the data type, just return the expr
+            if left_type.is_err() || right_type.is_err() {
+                return expr.clone();
+            }
+            let left_type = left_type.unwrap();
+            let right_type = right_type.unwrap();
+            if !left_type.eq(&right_type)
+                && is_support_data_type(&left_type)
+                && is_support_data_type(&right_type)
+                && is_comparison_op(op)
+            {
+                match (&left, &right) {
+                    (Expr::Literal(_), Expr::Literal(_)) => {
+                        // do nothing
+                    }
+                    (Expr::Literal(left_lit_value), _)
+                        if can_integer_literal_cast_to_type(
+                            left_lit_value,
+                            &right_type,
+                        ) =>
+                    {
+                        // cast the left literal to the right type
+                        return binary_expr(
+                            cast_to_other_scalar_expr(left_lit_value, &right_type),
+                            *op,
+                            right,
+                        );
+                    }
+                    (_, Expr::Literal(right_lit_value))
+                        if can_integer_literal_cast_to_type(
+                            right_lit_value,
+                            &left_type,
+                        ) =>
+                    {
+                        // cast the right literal to the left type
+                        return binary_expr(
+                            left,
+                            *op,
+                            cast_to_other_scalar_expr(right_lit_value, &left_type),
+                        );
+                    }
+                    (_, _) => {
+                        // do nothing
+                    }
+                };
+            }
+            // return the new binary op
+            binary_expr(left, *op, right)
+        }
+        // TODO: optimize in list
+        // Expr::InList { .. } => {}
+        // TODO: handle other expr type and dfs visit them
+        _ => expr,
+    }
+}
+
+fn cast_to_other_scalar_expr(origin_value: &ScalarValue, target_type: &DataType) -> Expr {
+    // null case
+    if origin_value.is_null() {
+        // if the origin value is null, just convert to another type of null value
+        // The target type must satisfy the `is_support_data_type` check, so we can unwrap safely
+        return lit(ScalarValue::try_from(target_type).unwrap());
+    }
+    // no null case
+    let value: i64 = match origin_value {
+        ScalarValue::Int8(Some(v)) => *v as i64,
+        ScalarValue::Int16(Some(v)) => *v as i64,
+        ScalarValue::Int32(Some(v)) => *v as i64,
+        ScalarValue::Int64(Some(v)) => *v as i64,
+        other_type => {
+            panic!("Invalid type and value {:?}", other_type);
+        }
+    };
+    lit(match target_type {
+        DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
+        DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
+        DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
+        DataType::Int64 => ScalarValue::Int64(Some(value)),
+        other_type => {
+            panic!("Invalid target data type {:?}", other_type);
+        }
+    })
+}
+
+fn is_comparison_op(op: &Operator) -> bool {
+    matches!(
+        op,
+        Operator::Eq
+            | Operator::NotEq
+            | Operator::Gt
+            | Operator::GtEq
+            | Operator::Lt
+            | Operator::LtEq
+    )
+}
+
+fn is_support_data_type(data_type: &DataType) -> bool {
+    // TODO support decimal with other data type
+    matches!(
+        data_type,
+        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
+    )
+}
+
+fn can_integer_literal_cast_to_type(
+    integer_lit_value: &ScalarValue,
+    target_type: &DataType,
+) -> bool {
+    if integer_lit_value.is_null() {
+        // null value can be cast to any type of null value
+        return true;
+    }
+    let (target_min, target_max) = match target_type {
+        DataType::Int8 => (i8::MIN as i128, i8::MAX as i128),
+        DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
+        DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
+        DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
+        _ => panic!("Error target data type {:?}", target_type),
+    };
+    let lit_value = match integer_lit_value {
+        ScalarValue::Int8(Some(v)) => *v as i128,
+        ScalarValue::Int16(Some(v)) => *v as i128,
+        ScalarValue::Int32(Some(v)) => *v as i128,
+        ScalarValue::Int64(Some(v)) => *v as i128,
+        _ => {
+            panic!("Invalid literal value {:?}", integer_lit_value)
+        }
+    };
+    if lit_value >= target_min && lit_value <= target_max {
+        return true;
+    }
+    false
+}
+
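// ---------------------------------------------------------------------------
// [Editorial note, not part of the patch] Why the bounds check widens to i128:
// lifting both the literal and the target type's (MIN, MAX) to i128 keeps the
// comparison lossless for every supported width, including i64::MIN / i64::MAX
// themselves. A quick sanity check against the function above:
//
//     assert!(can_integer_literal_cast_to_type(
//         &ScalarValue::Int64(Some(127)),
//         &DataType::Int8
//     )); // 127 fits in i8
//     assert!(!can_integer_literal_cast_to_type(
//         &ScalarValue::Int64(Some(128)),
//         &DataType::Int8
//     )); // 128 overflows i8
//     assert!(can_integer_literal_cast_to_type(
//         &ScalarValue::Int64(Some(i64::MAX)),
//         &DataType::Int64
//     ));
// ---------------------------------------------------------------------------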
target_type), + }; + let lit_value = match integer_lit_value { + ScalarValue::Int8(Some(v)) => *v as i128, + ScalarValue::Int16(Some(v)) => *v as i128, + ScalarValue::Int32(Some(v)) => *v as i128, + ScalarValue::Int64(Some(v)) => *v as i128, + _ => { + panic!("Invalid literal value {:?}", integer_lit_value) + } + }; + if lit_value >= target_min && lit_value <= target_max { + return true; + } + false +} + +#[cfg(test)] +mod tests { + use crate::pre_cast_lit_in_binary_comparison::visit_expr; + use arrow::datatypes::DataType; + use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue}; + use datafusion_expr::{col, lit, Expr}; + use std::collections::HashMap; + use std::sync::Arc; + + #[test] + fn test_not_cast_lit_comparison() { + let schema = expr_test_schema(); + // INT8(NULL) < INT32(12) + let lit_lt_lit = + lit(ScalarValue::Int8(None)).lt(lit(ScalarValue::Int32(Some(12)))); + assert_eq!(optimize_test(lit_lt_lit.clone(), &schema), lit_lt_lit); + // INT32(c1) > INT64(c2) + let c1_gt_c2 = col("c1").gt(col("c2")); + assert_eq!(optimize_test(c1_gt_c2.clone(), &schema), c1_gt_c2); + + // INT32(c1) < INT32(16), the type is same + let expr_lt = col("c1").lt(lit(ScalarValue::Int32(Some(16)))); + assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); + + // the 99999999999 is not within the range of MAX(int32) and MIN(int32), we don't cast the lit(99999999999) to int32 type + let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(99999999999)))); + assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); + } + + #[test] + fn test_pre_cast_lit_comparison() { + let schema = expr_test_schema(); + // c1 < INT64(16) -> c1 < cast(INT32(16)) + // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16) + let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(16)))); + let expected = col("c1").lt(lit(ScalarValue::Int32(Some(16)))); + assert_eq!(optimize_test(expr_lt, &schema), expected); + + // INT64(c2) = INT32(16) => INT64(c2) = INT64(16) + let c2_eq_lit = col("c2").eq(lit(ScalarValue::Int32(Some(16)))); + let expected = col("c2").eq(lit(ScalarValue::Int64(Some(16)))); + assert_eq!(optimize_test(c2_eq_lit.clone(), &schema), expected); + + // INT32(c1) < INT64(NULL) => INT32(c1) < INT32(NULL) + let c1_lt_lit_null = col("c1").lt(lit(ScalarValue::Int64(None))); + let expected = col("c1").lt(lit(ScalarValue::Int32(None))); + assert_eq!(optimize_test(c1_lt_lit_null.clone(), &schema), expected); + } + + fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { + visit_expr(expr, schema) + } + + fn expr_test_schema() -> DFSchemaRef { + Arc::new( + DFSchema::new_with_metadata( + vec![ + DFField::new(None, "c1", DataType::Int32, false), + DFField::new(None, "c2", DataType::Int64, false), + ], + HashMap::new(), + ) + .unwrap(), + ) + } +} From eae51335f2a805b408149733458a67900866dbe2 Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Sat, 20 Aug 2022 15:16:41 +0800 Subject: [PATCH 2/3] address comments and fix clippy --- .../core/tests/provider_filter_pushdown.rs | 16 ++++++++-------- .../src/pre_cast_lit_in_binary_comparison.rs | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/datafusion/core/tests/provider_filter_pushdown.rs b/datafusion/core/tests/provider_filter_pushdown.rs index 283a39a85d29..762988851300 100644 --- a/datafusion/core/tests/provider_filter_pushdown.rs +++ b/datafusion/core/tests/provider_filter_pushdown.rs @@ -147,16 +147,16 @@ impl TableProvider for CustomProvider { match &filters[0] { Expr::BinaryExpr { right, .. 
                let int_value = match &**right {
-                    Expr::Literal(ScalarValue::Int8(i)) => i.unwrap() as i64,
-                    Expr::Literal(ScalarValue::Int16(i)) => i.unwrap() as i64,
-                    Expr::Literal(ScalarValue::Int32(i)) => i.unwrap() as i64,
-                    Expr::Literal(ScalarValue::Int64(i)) => i.unwrap(),
+                    Expr::Literal(ScalarValue::Int8(Some(i))) => *i as i64,
+                    Expr::Literal(ScalarValue::Int16(Some(i))) => *i as i64,
+                    Expr::Literal(ScalarValue::Int32(Some(i))) => *i as i64,
+                    Expr::Literal(ScalarValue::Int64(Some(i))) => *i as i64,
                     Expr::Cast { expr, data_type: _ } => match expr.deref() {
                         Expr::Literal(lit_value) => match lit_value {
-                            ScalarValue::Int8(v) => v.unwrap() as i64,
-                            ScalarValue::Int16(v) => v.unwrap() as i64,
-                            ScalarValue::Int32(v) => v.unwrap() as i64,
-                            ScalarValue::Int64(v) => v.unwrap(),
+                            ScalarValue::Int8(Some(v)) => *v as i64,
+                            ScalarValue::Int16(Some(v)) => *v as i64,
+                            ScalarValue::Int32(Some(v)) => *v as i64,
+                            ScalarValue::Int64(Some(v)) => *v,
                             _ => unimplemented!(),
                         },
                         _ => unimplemented!(),

diff --git a/datafusion/optimizer/src/pre_cast_lit_in_binary_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_binary_comparison.rs
index 490c5a25c0d6..e2d9901e0ee1 100644
--- a/datafusion/optimizer/src/pre_cast_lit_in_binary_comparison.rs
+++ b/datafusion/optimizer/src/pre_cast_lit_in_binary_comparison.rs
@@ -268,12 +268,12 @@ mod tests {
         // INT64(c2) = INT32(16) => INT64(c2) = INT64(16)
         let c2_eq_lit = col("c2").eq(lit(ScalarValue::Int32(Some(16))));
         let expected = col("c2").eq(lit(ScalarValue::Int64(Some(16))));
-        assert_eq!(optimize_test(c2_eq_lit.clone(), &schema), expected);
+        assert_eq!(optimize_test(c2_eq_lit, &schema), expected);
 
         // INT32(c1) < INT64(NULL) => INT32(c1) < INT32(NULL)
         let c1_lt_lit_null = col("c1").lt(lit(ScalarValue::Int64(None)));
         let expected = col("c1").lt(lit(ScalarValue::Int32(None)));
-        assert_eq!(optimize_test(c1_lt_lit_null.clone(), &schema), expected);
+        assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
     }
 
     fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {

From 8455e3cc35ff2d4650600bd34d13777471f44c99 Mon Sep 17 00:00:00 2001
From: liukun4515
Date: Tue, 23 Aug 2022 11:24:31 +0800
Subject: [PATCH 3/3] change panic to result

---
 datafusion/core/src/execution/context.rs      |  4 +-
 .../core/tests/provider_filter_pushdown.rs    | 22 ++++-
 datafusion/optimizer/src/lib.rs               |  2 +-
 ...rison.rs => pre_cast_lit_in_comparison.rs} | 92 +++++++++++--------
 4 files changed, 76 insertions(+), 44 deletions(-)
 rename datafusion/optimizer/src/{pre_cast_lit_in_binary_comparison.rs => pre_cast_lit_in_comparison.rs} (83%)

diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index 601a812bfc8d..877ad5c0579f 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -106,7 +106,7 @@ use datafusion_optimizer::decorrelate_scalar_subquery::DecorrelateScalarSubquery
 use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists;
 use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn;
 use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
-use datafusion_optimizer::pre_cast_lit_in_binary_comparison::PreCastLitInBinaryComparisonExpressions;
+use datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions;
 use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
 use datafusion_sql::{
     parser::DFParser,
@@ -1361,7 +1361,7 @@ impl SessionState {
             // Simplify expressions first to maximize the chance
            // of applying other optimizations
             Arc::new(SimplifyExpressions::new()),
-            Arc::new(PreCastLitInBinaryComparisonExpressions::new()),
+            Arc::new(PreCastLitInComparisonExpressions::new()),
             Arc::new(DecorrelateWhereExists::new()),
             Arc::new(DecorrelateWhereIn::new()),
             Arc::new(DecorrelateScalarSubquery::new()),

diff --git a/datafusion/core/tests/provider_filter_pushdown.rs b/datafusion/core/tests/provider_filter_pushdown.rs
index 762988851300..8e6d695c9e9c 100644
--- a/datafusion/core/tests/provider_filter_pushdown.rs
+++ b/datafusion/core/tests/provider_filter_pushdown.rs
@@ -31,6 +31,7 @@ use datafusion::physical_plan::{
 };
 use datafusion::prelude::*;
 use datafusion::scalar::ScalarValue;
+use datafusion_common::DataFusionError;
 use std::ops::Deref;
 use std::sync::Arc;
 
@@ -157,11 +158,26 @@ impl TableProvider for CustomProvider {
                             ScalarValue::Int16(Some(v)) => *v as i64,
                             ScalarValue::Int32(Some(v)) => *v as i64,
                             ScalarValue::Int64(Some(v)) => *v,
-                            _ => unimplemented!(),
+                            other_value => {
+                                return Err(DataFusionError::NotImplemented(format!(
+                                    "Do not support value {:?}",
+                                    other_value
+                                )))
+                            }
                         },
-                        _ => unimplemented!(),
+                        other_expr => {
+                            return Err(DataFusionError::NotImplemented(format!(
+                                "Do not support expr {:?}",
+                                other_expr
+                            )))
+                        }
                     },
-                    _ => unimplemented!(),
+                    other_expr => {
+                        return Err(DataFusionError::NotImplemented(format!(
+                            "Do not support expr {:?}",
+                            other_expr
+                        )))
+                    }
                 };
 
                 Ok(Arc::new(CustomPlan {

diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index cb79f2164ed1..60c450992de5 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -33,7 +33,7 @@ pub mod single_distinct_to_groupby;
 pub mod subquery_filter_to_join;
 pub mod utils;
 
-pub mod pre_cast_lit_in_binary_comparison;
+pub mod pre_cast_lit_in_comparison;
 pub mod rewrite_disjunctive_predicate;
 #[cfg(test)]
 pub mod test;
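// ---------------------------------------------------------------------------
// [Editorial note, not part of the patch] The renamed rule below applies the
// same panic-to-Result conversion throughout: helpers return
// DataFusionError::Internal (or NotImplemented at API boundaries, as above)
// instead of panicking, and callers propagate the error with `?`. The shape of
// the pattern, with a hypothetical helper:
//
//     use datafusion_common::{DataFusionError, Result, ScalarValue};
//
//     fn expect_i64(v: &ScalarValue) -> Result<i64> {
//         match v {
//             ScalarValue::Int64(Some(i)) => Ok(*i),
//             other => Err(DataFusionError::Internal(format!(
//                 "unexpected scalar {:?}",
//                 other
//             ))),
//         }
//     }
// ---------------------------------------------------------------------------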
diff --git a/datafusion/optimizer/src/pre_cast_lit_in_binary_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
similarity index 83%
rename from datafusion/optimizer/src/pre_cast_lit_in_binary_comparison.rs
rename to datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
index e2d9901e0ee1..0c16f7921c32 100644
--- a/datafusion/optimizer/src/pre_cast_lit_in_binary_comparison.rs
+++ b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
@@ -19,7 +19,7 @@
 //! It avoids wrapping the non-literal side in `Expr::Cast` by casting the literal expr instead.
 use crate::{OptimizerConfig, OptimizerRule};
 use arrow::datatypes::DataType;
-use datafusion_common::{DFSchemaRef, Result, ScalarValue};
+use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
 use datafusion_expr::utils::from_plan;
 use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan, Operator};
 
@@ -44,15 +44,15 @@ use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan, Operator};
 /// and then folded to `Filter: c1 > INT32(10)`, if the DataType of c1 is INT32.
 ///
 #[derive(Default)]
-pub struct PreCastLitInBinaryComparisonExpressions {}
+pub struct PreCastLitInComparisonExpressions {}
 
-impl PreCastLitInBinaryComparisonExpressions {
+impl PreCastLitInComparisonExpressions {
     pub fn new() -> Self {
         Self::default()
     }
 }
 
-impl OptimizerRule for PreCastLitInBinaryComparisonExpressions {
+impl OptimizerRule for PreCastLitInComparisonExpressions {
     fn optimize(
         &self,
         plan: &LogicalPlan,
         _optimizer_config: &mut OptimizerConfig,
@@ -62,7 +62,7 @@ impl OptimizerRule for PreCastLitInComparisonExpressions {
     }
 
     fn name(&self) -> &str {
-        "pre_cast_lit_in_binary_comparison"
+        "pre_cast_lit_in_comparison"
     }
 }
 
@@ -78,24 +78,24 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
         .expressions()
         .into_iter()
         .map(|expr| visit_expr(expr, schema))
-        .collect::<Vec<Expr>>();
+        .collect::<Result<Vec<Expr>>>()?;
 
     from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
 }
 
 // Visit every kind of expr; if the current expr has children, visit the children first.
-fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Expr {
+fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result<Expr> {
     // traverse the expr by dfs
     match &expr {
         Expr::BinaryExpr { left, op, right } => {
             // dfs visit the left and right expr
-            let left = visit_expr(*left.clone(), schema);
-            let right = visit_expr(*right.clone(), schema);
+            let left = visit_expr(*left.clone(), schema)?;
+            let right = visit_expr(*right.clone(), schema)?;
             let left_type = left.get_type(schema);
             let right_type = right.get_type(schema);
             // can't get the data type, just return the expr
             if left_type.is_err() || right_type.is_err() {
-                return expr.clone();
+                return Ok(expr.clone());
             }
             let left_type = left_type.unwrap();
             let right_type = right_type.unwrap();
@@ -112,27 +112,28 @@ fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Expr {
                         if can_integer_literal_cast_to_type(
                             left_lit_value,
                             &right_type,
-                        ) =>
+                        )? =>
                     {
                         // cast the left literal to the right type
-                        return binary_expr(
-                            cast_to_other_scalar_expr(left_lit_value, &right_type),
+                        return Ok(binary_expr(
+                            cast_to_other_scalar_expr(left_lit_value, &right_type)?,
                             *op,
                             right,
-                        );
+                        ));
                     }
                     (_, Expr::Literal(right_lit_value))
                         if can_integer_literal_cast_to_type(
                             right_lit_value,
                             &left_type,
-                        ) =>
+                        )
+                        .unwrap() =>
                     {
                         // cast the right literal to the left type
-                        return binary_expr(
+                        return Ok(binary_expr(
                             left,
                             *op,
-                            cast_to_other_scalar_expr(right_lit_value, &left_type),
-                        );
+                            cast_to_other_scalar_expr(right_lit_value, &left_type)?,
+                        ));
                     }
                     (_, _) => {
                         // do nothing
@@ -140,21 +141,24 @@ fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Expr {
                     }
                 };
             }
             // return the new binary op
-            binary_expr(left, *op, right)
+            Ok(binary_expr(left, *op, right))
         }
         // TODO: optimize in list
         // Expr::InList { .. } => {}
         // TODO: handle other expr type and dfs visit them
-        _ => expr,
+        _ => Ok(expr),
     }
 }
 
-fn cast_to_other_scalar_expr(origin_value: &ScalarValue, target_type: &DataType) -> Expr {
+fn cast_to_other_scalar_expr(
+    origin_value: &ScalarValue,
+    target_type: &DataType,
+) -> Result<Expr> {
     // null case
     if origin_value.is_null() {
         // if the origin value is null, just convert to another type of null value
         // The target type must satisfy the `is_support_data_type` check, so we can unwrap safely
-        return lit(ScalarValue::try_from(target_type).unwrap());
+        return Ok(lit(ScalarValue::try_from(target_type).unwrap()));
     }
     // no null case
     let value: i64 = match origin_value {
@@ -162,19 +166,25 @@ fn cast_to_other_scalar_expr(origin_value: &ScalarValue, target_type: &DataType)
         ScalarValue::Int16(Some(v)) => *v as i64,
         ScalarValue::Int32(Some(v)) => *v as i64,
         ScalarValue::Int64(Some(v)) => *v as i64,
-        other_type => {
-            panic!("Invalid type and value {:?}", other_type);
+        other_value => {
+            return Err(DataFusionError::Internal(format!(
+                "Invalid type and value {}",
+                other_value
+            )))
         }
     };
-    lit(match target_type {
+    Ok(lit(match target_type {
         DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
         DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
         DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
         DataType::Int64 => ScalarValue::Int64(Some(value)),
         other_type => {
-            panic!("Invalid target data type {:?}", other_type);
+            return Err(DataFusionError::Internal(format!(
+                "Invalid target data type {:?}",
+                other_type
+            )))
         }
-    })
+    }))
 }
 
 fn is_comparison_op(op: &Operator) -> bool {
@@ -200,36 +210,42 @@ fn can_integer_literal_cast_to_type(
     integer_lit_value: &ScalarValue,
     target_type: &DataType,
-) -> bool {
+) -> Result<bool> {
     if integer_lit_value.is_null() {
         // null value can be cast to any type of null value
-        return true;
+        return Ok(true);
     }
     let (target_min, target_max) = match target_type {
         DataType::Int8 => (i8::MIN as i128, i8::MAX as i128),
         DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
         DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
         DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
-        _ => panic!("Error target data type {:?}", target_type),
+        other_type => {
+            return Err(DataFusionError::Internal(format!(
+                "Error target data type {:?}",
+                other_type
+            )))
+        }
     };
     let lit_value = match integer_lit_value {
         ScalarValue::Int8(Some(v)) => *v as i128,
         ScalarValue::Int16(Some(v)) => *v as i128,
         ScalarValue::Int32(Some(v)) => *v as i128,
         ScalarValue::Int64(Some(v)) => *v as i128,
-        _ => {
-            panic!("Invalid literal value {:?}", integer_lit_value)
+        other_value => {
+            return Err(DataFusionError::Internal(format!(
+                "Invalid literal value {:?}",
+                other_value
+            )))
        }
     };
-    if lit_value >= target_min && lit_value <= target_max {
-        return true;
-    }
-    false
+
+    Ok(lit_value >= target_min && lit_value <= target_max)
 }
 
 #[cfg(test)]
 mod tests {
-    use crate::pre_cast_lit_in_binary_comparison::visit_expr;
+    use crate::pre_cast_lit_in_comparison::visit_expr;
     use arrow::datatypes::DataType;
     use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue};
     use datafusion_expr::{col, lit, Expr};
@@ -277,7 +293,7 @@ mod tests {
     }
 
     fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
-        visit_expr(expr, schema)
+        visit_expr(expr, schema).unwrap()
     }
 
     fn expr_test_schema() -> DFSchemaRef {