From a4e74c0bc9b6e46dd151d40e5c881b7961fecccc Mon Sep 17 00:00:00 2001 From: Kun Liu Date: Tue, 30 Aug 2022 17:48:15 +0800 Subject: [PATCH] support inlist for pre cast literal expression (#3270) * support decimal for the PreCastLitInComparisonExpressions rule * address comments * support list --- .../src/pre_cast_lit_in_comparison.rs | 181 +++++++++++++++++- 1 file changed, 178 insertions(+), 3 deletions(-) diff --git a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs index 6e89afd600be..793eca2f37f1 100644 --- a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs +++ b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs @@ -24,7 +24,9 @@ use arrow::datatypes::{ use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; use datafusion_expr::utils::from_plan; -use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan, Operator}; +use datafusion_expr::{ + binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, +}; /// The rule can be only used to the numeric binary comparison with literal expr, like below pattern: /// `left_expr comparison_op literal_expr` or `literal_expr comparison_op right_expr`. @@ -144,8 +146,57 @@ impl ExprRewriter for PreCastLitExprRewriter { // return the new binary op Ok(binary_expr(left, *op, right)) } - // TODO: optimize in list - // Expr::InList { .. } => {} + Expr::InList { + expr: left_expr, + list, + negated, + } => { + let left = left_expr.as_ref().clone(); + let left_type = left.get_type(&self.schema); + if left_type.is_err() { + // error data type + return Ok(expr); + } + let left_type = left_type?; + if !is_support_data_type(&left_type) { + // not supported data type + return Ok(expr); + } + let right_exprs = list + .iter() + .map(|right| { + let right_type = right.get_type(&self.schema)?; + if !is_support_data_type(&right_type) { + return Err(DataFusionError::Internal(format!( + "The type of list expr {} not support", + &right_type + ))); + } + match right { + Expr::Literal(right_lit_value) => { + let casted_scalar_value = + try_cast_literal_to_type(right_lit_value, &left_type)?; + if let Some(value) = casted_scalar_value { + Ok(lit(value)) + } else { + Err(DataFusionError::Internal(format!( + "Can't cast the list expr {:?} to type {:?}", + right_lit_value, &left_type + ))) + } + } + other_expr => Err(DataFusionError::Internal(format!( + "Only support literal expr to optimize, but the expr is {:?}", + &other_expr + ))), + } + }) + .collect::>>(); + match right_exprs { + Ok(right_exprs) => Ok(in_list(left, right_exprs, *negated)), + Err(_) => Ok(expr), + } + } // TODO: handle other expr type and dfs visit them _ => Ok(expr), } @@ -384,6 +435,129 @@ mod tests { assert_eq!(optimize_test(expr_lt, &schema), expected); } + #[test] + fn test_not_list_cast_lit_comparison() { + let schema = expr_test_schema(); + // left type is not supported + // FLOAT32(C5) in ... + let expr_lt = col("c5").in_list( + vec![ + lit(ScalarValue::Int64(Some(12))), + lit(ScalarValue::Int32(Some(12))), + ], + false, + ); + assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); + + // INT32(C1) in (FLOAT32(1.23), INT32(12), INT64(12)) + let expr_lt = col("c1").in_list( + vec![ + lit(ScalarValue::Int32(Some(12))), + lit(ScalarValue::Int64(Some(12))), + lit(ScalarValue::Float32(Some(1.23))), + ], + false, + ); + assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); + + // INT32(C1) in (INT64(99999999999), INT64(12)) + let expr_lt = col("c1").in_list( + vec![ + lit(ScalarValue::Int32(Some(12))), + lit(ScalarValue::Int64(Some(99999999999))), + ], + false, + ); + assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); + + // DECIMAL(C3) in (INT64(12), INT32(12), DECIMAL(128,12,3)) + let expr_lt = col("c3").in_list( + vec![ + lit(ScalarValue::Int32(Some(12))), + lit(ScalarValue::Int64(Some(12))), + lit(ScalarValue::Decimal128(Some(128), 12, 3)), + ], + false, + ); + assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); + } + + #[test] + fn test_pre_list_cast_lit_comparison() { + let schema = expr_test_schema(); + // INT32(C1) IN (INT32(12),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24)) + let expr_lt = col("c1").in_list( + vec![ + lit(ScalarValue::Int32(Some(12))), + lit(ScalarValue::Int64(Some(24))), + ], + false, + ); + let expected = col("c1").in_list( + vec![ + lit(ScalarValue::Int32(Some(12))), + lit(ScalarValue::Int32(Some(24))), + ], + false, + ); + assert_eq!(optimize_test(expr_lt, &schema), expected); + // INT32(C2) IN (INT64(NULL),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24)) + let expr_lt = col("c2").in_list( + vec![ + lit(ScalarValue::Int64(None)), + lit(ScalarValue::Int32(Some(14))), + ], + false, + ); + let expected = col("c2").in_list( + vec![ + lit(ScalarValue::Int64(None)), + lit(ScalarValue::Int64(Some(14))), + ], + false, + ); + + assert_eq!(optimize_test(expr_lt, &schema), expected); + + // decimal test case + let expr_lt = col("c3").in_list( + vec![ + lit(ScalarValue::Int32(Some(12))), + lit(ScalarValue::Int64(Some(24))), + lit(ScalarValue::Decimal128(Some(128), 10, 2)), + lit(ScalarValue::Decimal128(Some(1280), 10, 3)), + ], + false, + ); + let expected = col("c3").in_list( + vec![ + lit(ScalarValue::Decimal128(Some(1200), 18, 2)), + lit(ScalarValue::Decimal128(Some(2400), 18, 2)), + lit(ScalarValue::Decimal128(Some(128), 18, 2)), + lit(ScalarValue::Decimal128(Some(128), 18, 2)), + ], + false, + ); + assert_eq!(optimize_test(expr_lt, &schema), expected); + + // INT32(12) IN (.....) + let expr_lt = lit(ScalarValue::Int32(Some(12))).in_list( + vec![ + lit(ScalarValue::Int32(Some(12))), + lit(ScalarValue::Int64(Some(12))), + ], + false, + ); + let expected = lit(ScalarValue::Int32(Some(12))).in_list( + vec![ + lit(ScalarValue::Int32(Some(12))), + lit(ScalarValue::Int32(Some(12))), + ], + false, + ); + assert_eq!(optimize_test(expr_lt, &schema), expected); + } + #[test] fn aliased() { let schema = expr_test_schema(); @@ -423,6 +597,7 @@ mod tests { DFField::new(None, "c2", DataType::Int64, false), DFField::new(None, "c3", DataType::Decimal128(18, 2), false), DFField::new(None, "c4", DataType::Decimal128(38, 37), false), + DFField::new(None, "c5", DataType::Float32, false), ], HashMap::new(), )