diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 2a9d011245dd..a6a508d36843 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -111,6 +111,7 @@ use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys; use datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions; use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use datafusion_optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin; +use datafusion_optimizer::type_coercion::TypeCoercion; use datafusion_sql::{ parser::DFParser, planner::{ContextProvider, SqlToRel}, @@ -1433,6 +1434,9 @@ impl SessionState { } rules.push(Arc::new(ReduceOuterJoin::new())); rules.push(Arc::new(FilterPushDown::new())); + // we do type coercion after filter push down so that we don't push CAST filters to Parquet + // until https://github.com/apache/arrow-datafusion/issues/3289 is resolved + rules.push(Arc::new(TypeCoercion::new())); rules.push(Arc::new(LimitPushDown::new())); rules.push(Arc::new(SingleDistinctToGroupBy::new())); diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 8d7e0e9e48e9..a2d8683618aa 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -128,13 +128,13 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { name += "END"; Ok(name) } - Expr::Cast { expr, data_type } => { - let expr = create_physical_name(expr, false)?; - Ok(format!("CAST({} AS {:?})", expr, data_type)) + Expr::Cast { expr, .. } => { + // CAST does not change the expression name + create_physical_name(expr, false) } - Expr::TryCast { expr, data_type } => { - let expr = create_physical_name(expr, false)?; - Ok(format!("TRY_CAST({} AS {:?})", expr, data_type)) + Expr::TryCast { expr, .. } => { + // CAST does not change the expression name + create_physical_name(expr, false) } Expr::Not(expr) => { let expr = create_physical_name(expr, false)?; diff --git a/datafusion/core/tests/dataframe_functions.rs b/datafusion/core/tests/dataframe_functions.rs index 19694285cc31..0d3631b18a70 100644 --- a/datafusion/core/tests/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe_functions.rs @@ -667,14 +667,14 @@ async fn test_fn_substr() -> Result<()> { async fn test_cast() -> Result<()> { let expr = cast(col("b"), DataType::Float64); let expected = vec![ - "+-------------------------+", - "| CAST(test.b AS Float64) |", - "+-------------------------+", - "| 1 |", - "| 10 |", - "| 10 |", - "| 100 |", - "+-------------------------+", + "+--------+", + "| test.b |", + "+--------+", + "| 1 |", + "| 10 |", + "| 10 |", + "| 100 |", + "+--------+", ]; assert_fn_batches!(expr, expected); diff --git a/datafusion/core/tests/parquet_pruning.rs b/datafusion/core/tests/parquet_pruning.rs index 0c3acf9fb85f..b6c286763501 100644 --- a/datafusion/core/tests/parquet_pruning.rs +++ b/datafusion/core/tests/parquet_pruning.rs @@ -647,6 +647,7 @@ impl ContextWithParquet { let pretty_input = pretty_format_batches(&input).unwrap().to_string(); let logical_plan = self.ctx.optimize(&logical_plan).expect("optimizing plan"); + let physical_plan = self .ctx .create_physical_plan(&logical_plan) diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 7e0e785da805..357addbc0e21 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -462,11 +462,11 @@ async fn csv_query_external_table_sum() { "SELECT SUM(CAST(c7 AS BIGINT)), SUM(CAST(c8 AS BIGINT)) FROM aggregate_test_100"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+-------------------------------------------+-------------------------------------------+", - "| SUM(CAST(aggregate_test_100.c7 AS Int64)) | SUM(CAST(aggregate_test_100.c8 AS Int64)) |", - "+-------------------------------------------+-------------------------------------------+", - "| 13060 | 3017641 |", - "+-------------------------------------------+-------------------------------------------+", + "+----------------------------+----------------------------+", + "| SUM(aggregate_test_100.c7) | SUM(aggregate_test_100.c8) |", + "+----------------------------+----------------------------+", + "| 13060 | 3017641 |", + "+----------------------------+----------------------------+", ]; assert_batches_eq!(expected, &actual); } @@ -555,6 +555,7 @@ async fn csv_query_count_one() { } #[tokio::test] +#[ignore] // https://github.com/apache/arrow-datafusion/issues/3353 async fn csv_query_approx_count() -> Result<()> { let ctx = SessionContext::new(); register_aggregate_csv(&ctx).await?; @@ -571,6 +572,24 @@ async fn csv_query_approx_count() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_approx_count_dupe_expr_aliased() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = + "SELECT approx_distinct(c9) a, approx_distinct(c9) b FROM aggregate_test_100"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----+-----+", + "| a | b |", + "+-----+-----+", + "| 100 | 100 |", + "+-----+-----+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + // This test executes the APPROX_PERCENTILE_CONT aggregation against the test // data, asserting the estimated quantiles are ±5% their actual values. // diff --git a/datafusion/core/tests/sql/avro.rs b/datafusion/core/tests/sql/avro.rs index f4ff4cd7cc0f..8fdef28bd309 100644 --- a/datafusion/core/tests/sql/avro.rs +++ b/datafusion/core/tests/sql/avro.rs @@ -37,18 +37,18 @@ async fn avro_query() { let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+----+-----------------------------------------+", - "| id | CAST(alltypes_plain.string_col AS Utf8) |", - "+----+-----------------------------------------+", - "| 4 | 0 |", - "| 5 | 1 |", - "| 6 | 0 |", - "| 7 | 1 |", - "| 2 | 0 |", - "| 3 | 1 |", - "| 0 | 0 |", - "| 1 | 1 |", - "+----+-----------------------------------------+", + "+----+---------------------------+", + "| id | alltypes_plain.string_col |", + "+----+---------------------------+", + "| 4 | 0 |", + "| 5 | 1 |", + "| 6 | 0 |", + "| 7 | 1 |", + "| 2 | 0 |", + "| 3 | 1 |", + "| 0 | 0 |", + "| 1 | 1 |", + "+----+---------------------------+", ]; assert_batches_eq!(expected, &actual); @@ -84,26 +84,26 @@ async fn avro_query_multiple_files() { let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+----+-----------------------------------------+", - "| id | CAST(alltypes_plain.string_col AS Utf8) |", - "+----+-----------------------------------------+", - "| 4 | 0 |", - "| 5 | 1 |", - "| 6 | 0 |", - "| 7 | 1 |", - "| 2 | 0 |", - "| 3 | 1 |", - "| 0 | 0 |", - "| 1 | 1 |", - "| 4 | 0 |", - "| 5 | 1 |", - "| 6 | 0 |", - "| 7 | 1 |", - "| 2 | 0 |", - "| 3 | 1 |", - "| 0 | 0 |", - "| 1 | 1 |", - "+----+-----------------------------------------+", + "+----+---------------------------+", + "| id | alltypes_plain.string_col |", + "+----+---------------------------+", + "| 4 | 0 |", + "| 5 | 1 |", + "| 6 | 0 |", + "| 7 | 1 |", + "| 2 | 0 |", + "| 3 | 1 |", + "| 0 | 0 |", + "| 1 | 1 |", + "| 4 | 0 |", + "| 5 | 1 |", + "| 6 | 0 |", + "| 7 | 1 |", + "| 2 | 0 |", + "| 3 | 1 |", + "| 0 | 0 |", + "| 1 | 1 |", + "+----+---------------------------+", ]; assert_batches_eq!(expected, &actual); diff --git a/datafusion/core/tests/sql/decimal.rs b/datafusion/core/tests/sql/decimal.rs index 8ded8752dfcb..7c74cdd52f0e 100644 --- a/datafusion/core/tests/sql/decimal.rs +++ b/datafusion/core/tests/sql/decimal.rs @@ -27,11 +27,11 @@ async fn decimal_cast() -> Result<()> { actual[0].schema().field(0).data_type() ); let expected = vec![ - "+------------------------------------------+", - "| CAST(Float64(1.23) AS Decimal128(10, 4)) |", - "+------------------------------------------+", - "| 1.2300 |", - "+------------------------------------------+", + "+---------------+", + "| Float64(1.23) |", + "+---------------+", + "| 1.2300 |", + "+---------------+", ]; assert_batches_eq!(expected, &actual); @@ -42,11 +42,11 @@ async fn decimal_cast() -> Result<()> { actual[0].schema().field(0).data_type() ); let expected = vec![ - "+---------------------------------------------------------------------+", - "| CAST(CAST(Float64(1.23) AS Decimal128(10, 3)) AS Decimal128(10, 4)) |", - "+---------------------------------------------------------------------+", - "| 1.2300 |", - "+---------------------------------------------------------------------+", + "+---------------+", + "| Float64(1.23) |", + "+---------------+", + "| 1.2300 |", + "+---------------+", ]; assert_batches_eq!(expected, &actual); @@ -57,11 +57,11 @@ async fn decimal_cast() -> Result<()> { actual[0].schema().field(0).data_type() ); let expected = vec![ - "+--------------------------------------------+", - "| CAST(Float64(1.2345) AS Decimal128(24, 2)) |", - "+--------------------------------------------+", - "| 1.23 |", - "+--------------------------------------------+", + "+-----------------+", + "| Float64(1.2345) |", + "+-----------------+", + "| 1.23 |", + "+-----------------+", ]; assert_batches_eq!(expected, &actual); @@ -550,25 +550,25 @@ async fn decimal_arithmetic_op() -> Result<()> { actual[0].schema().field(0).data_type() ); let expected = vec![ - "+----------------------------------------------------------------+", - "| decimal_simple.c1 / CAST(Float64(0.00001) AS Decimal128(5, 5)) |", - "+----------------------------------------------------------------+", - "| 1.000000000000 |", - "| 2.000000000000 |", - "| 2.000000000000 |", - "| 3.000000000000 |", - "| 3.000000000000 |", - "| 3.000000000000 |", - "| 4.000000000000 |", - "| 4.000000000000 |", - "| 4.000000000000 |", - "| 4.000000000000 |", - "| 5.000000000000 |", - "| 5.000000000000 |", - "| 5.000000000000 |", - "| 5.000000000000 |", - "| 5.000000000000 |", - "+----------------------------------------------------------------+", + "+--------------------------------------+", + "| decimal_simple.c1 / Float64(0.00001) |", + "+--------------------------------------+", + "| 1.000000000000 |", + "| 2.000000000000 |", + "| 2.000000000000 |", + "| 3.000000000000 |", + "| 3.000000000000 |", + "| 3.000000000000 |", + "| 4.000000000000 |", + "| 4.000000000000 |", + "| 4.000000000000 |", + "| 4.000000000000 |", + "| 5.000000000000 |", + "| 5.000000000000 |", + "| 5.000000000000 |", + "| 5.000000000000 |", + "| 5.000000000000 |", + "+--------------------------------------+", ]; assert_batches_eq!(expected, &actual); @@ -609,25 +609,25 @@ async fn decimal_arithmetic_op() -> Result<()> { actual[0].schema().field(0).data_type() ); let expected = vec![ - "+----------------------------------------------------------------+", - "| decimal_simple.c5 % CAST(Float64(0.00001) AS Decimal128(5, 5)) |", - "+----------------------------------------------------------------+", - "| 0.0000040 |", - "| 0.0000050 |", - "| 0.0000090 |", - "| 0.0000020 |", - "| 0.0000050 |", - "| 0.0000010 |", - "| 0.0000040 |", - "| 0.0000000 |", - "| 0.0000000 |", - "| 0.0000040 |", - "| 0.0000020 |", - "| 0.0000080 |", - "| 0.0000030 |", - "| 0.0000080 |", - "| 0.0000000 |", - "+----------------------------------------------------------------+", + "+--------------------------------------+", + "| decimal_simple.c5 % Float64(0.00001) |", + "+--------------------------------------+", + "| 0.0000040 |", + "| 0.0000050 |", + "| 0.0000090 |", + "| 0.0000020 |", + "| 0.0000050 |", + "| 0.0000010 |", + "| 0.0000040 |", + "| 0.0000000 |", + "| 0.0000000 |", + "| 0.0000040 |", + "| 0.0000020 |", + "| 0.0000080 |", + "| 0.0000030 |", + "| 0.0000080 |", + "| 0.0000000 |", + "+--------------------------------------+", ]; assert_batches_eq!(expected, &actual); diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 894d45564272..a63839c2f1dd 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -653,7 +653,7 @@ order by let expected = "\ Sort: #revenue DESC NULLS FIRST\ \n Projection: #customer.c_custkey, #customer.c_name, #SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, #customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone, #customer.c_comment\ - \n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * Int64(1) - #lineitem.l_discount)]]\ + \n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * CAST(Int64(1) AS Float64) - #lineitem.l_discount)]]\ \n Inner Join: #customer.c_nationkey = #nation.n_nationkey\ \n Inner Join: #orders.o_orderkey = #lineitem.l_orderkey\ \n Inner Join: #customer.c_custkey = #orders.o_custkey\ @@ -663,7 +663,7 @@ order by \n Filter: #lineitem.l_returnflag = Utf8(\"R\")\ \n TableScan: lineitem projection=[l_orderkey, l_extendedprice, l_discount, l_returnflag], partial_filters=[#lineitem.l_returnflag = Utf8(\"R\")]\ \n TableScan: nation projection=[n_nationkey, n_name]"; - assert_eq!(format!("{:?}", plan.unwrap()), expected); + assert_eq!(expected, format!("{:?}", plan.unwrap()),); Ok(()) } @@ -694,7 +694,7 @@ async fn test_physical_plan_display_indent() { " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 9000)", " AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)]", " CoalesceBatchesExec: target_batch_size=4096", - " FilterExec: c12@1 < CAST(10 AS Float64)", + " FilterExec: c12@1 < 10", " RepartitionExec: partitioning=RoundRobinBatch(9000)", " CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None, projection=[c1, c12]", ]; diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index 0c59724bd617..3ca2c4738f45 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -247,8 +247,8 @@ async fn query_not() -> Result<()> { async fn csv_query_sum_cast() { let ctx = SessionContext::new(); register_aggregate_csv_by_sql(&ctx).await; - // c8 = i32; c9 = i64 - let sql = "SELECT c8 + c9 FROM aggregate_test_100"; + // c8 = i32; c6 = i64 + let sql = "SELECT c8 + c6 FROM aggregate_test_100"; // check that the physical and logical schemas are equal execute(&ctx, sql).await; } diff --git a/datafusion/core/tests/sql/functions.rs b/datafusion/core/tests/sql/functions.rs index e7bcb24c744d..802810d6490f 100644 --- a/datafusion/core/tests/sql/functions.rs +++ b/datafusion/core/tests/sql/functions.rs @@ -43,12 +43,12 @@ async fn csv_query_cast() -> Result<()> { let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+-----------------------------------------+", - "| CAST(aggregate_test_100.c12 AS Float32) |", - "+-----------------------------------------+", - "| 0.39144436 |", - "| 0.3887028 |", - "+-----------------------------------------+", + "+------------------------+", + "| aggregate_test_100.c12 |", + "+------------------------+", + "| 0.39144436 |", + "| 0.3887028 |", + "+------------------------+", ]; assert_batches_eq!(expected, &actual); @@ -64,12 +64,12 @@ async fn csv_query_cast_literal() -> Result<()> { let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+--------------------+---------------------------+", - "| c12 | CAST(Int64(1) AS Float32) |", - "+--------------------+---------------------------+", - "| 0.9294097332465232 | 1 |", - "| 0.3114712539863804 | 1 |", - "+--------------------+---------------------------+", + "+--------------------+----------+", + "| c12 | Int64(1) |", + "+--------------------+----------+", + "| 0.9294097332465232 | 1 |", + "| 0.3114712539863804 | 1 |", + "+--------------------+----------+", ]; assert_batches_eq!(expected, &actual); @@ -98,14 +98,14 @@ async fn query_concat() -> Result<()> { let sql = "SELECT concat(c1, '-hi-', cast(c2 as varchar)) FROM test"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+----------------------------------------------------+", - "| concat(test.c1,Utf8(\"-hi-\"),CAST(test.c2 AS Utf8)) |", - "+----------------------------------------------------+", - "| -hi-0 |", - "| a-hi-1 |", - "| aa-hi- |", - "| aaa-hi-3 |", - "+----------------------------------------------------+", + "+--------------------------------------+", + "| concat(test.c1,Utf8(\"-hi-\"),test.c2) |", + "+--------------------------------------+", + "| -hi-0 |", + "| a-hi-1 |", + "| aa-hi- |", + "| aaa-hi-3 |", + "+--------------------------------------+", ]; assert_batches_eq!(expected, &actual); Ok(()) @@ -133,14 +133,14 @@ async fn query_array() -> Result<()> { let sql = "SELECT make_array(c1, cast(c2 as varchar)) FROM test"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+------------------------------------------+", - "| makearray(test.c1,CAST(test.c2 AS Utf8)) |", - "+------------------------------------------+", - "| [, 0] |", - "| [a, 1] |", - "| [aa, ] |", - "| [aaa, 3] |", - "+------------------------------------------+", + "+----------------------------+", + "| makearray(test.c1,test.c2) |", + "+----------------------------+", + "| [, 0] |", + "| [a, 1] |", + "| [aa, ] |", + "| [aaa, 3] |", + "+----------------------------+", ]; assert_batches_eq!(expected, &actual); Ok(()) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index b899ac220737..4ff29ea392ac 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -1438,9 +1438,9 @@ async fn reduce_left_join_1() -> Result<()> { "Explain [plan_type:Utf8, plan:Utf8]", " Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, #t2.t2_id, #t2.t2_name, #t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: #t1.t1_id < Int64(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Filter: CAST(#t1.t1_id AS Int64) < Int64(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Filter: #t2.t2_id < Int64(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: CAST(#t2.t2_id AS Int64) < Int64(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", ]; let formatted = plan.display_indent_schema().to_string(); @@ -1481,7 +1481,7 @@ async fn reduce_left_join_2() -> Result<()> { let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, #t2.t2_id, #t2.t2_name, #t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: #t2.t2_int < Int64(10) OR #t1.t1_int > Int64(2) AND #t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: CAST(#t2.t2_int AS Int64) < Int64(10) OR CAST(#t1.t1_int AS Int64) > Int64(2) AND #t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", @@ -1528,9 +1528,9 @@ async fn reduce_left_join_3() -> Result<()> { " Projection: #t3.t1_id, #t3.t1_name, #t3.t1_int, alias=t3 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, alias=t3 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: #t1.t1_id < Int64(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Filter: CAST(#t1.t1_id AS Int64) < Int64(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Filter: #t2.t2_int < Int64(3) AND #t2.t2_id < Int64(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: CAST(#t2.t2_int AS Int64) < Int64(3) AND CAST(#t2.t2_id AS Int64) < Int64(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", ]; diff --git a/datafusion/core/tests/sql/parquet.rs b/datafusion/core/tests/sql/parquet.rs index 51304f60876c..8bec4f1dd3db 100644 --- a/datafusion/core/tests/sql/parquet.rs +++ b/datafusion/core/tests/sql/parquet.rs @@ -31,18 +31,18 @@ async fn parquet_query() { let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+----+-----------------------------------------+", - "| id | CAST(alltypes_plain.string_col AS Utf8) |", - "+----+-----------------------------------------+", - "| 4 | 0 |", - "| 5 | 1 |", - "| 6 | 0 |", - "| 7 | 1 |", - "| 2 | 0 |", - "| 3 | 1 |", - "| 0 | 0 |", - "| 1 | 1 |", - "+----+-----------------------------------------+", + "+----+---------------------------+", + "| id | alltypes_plain.string_col |", + "+----+---------------------------+", + "| 4 | 0 |", + "| 5 | 1 |", + "| 6 | 0 |", + "| 7 | 1 |", + "| 2 | 0 |", + "| 3 | 1 |", + "| 0 | 0 |", + "| 1 | 1 |", + "+----+---------------------------+", ]; assert_batches_eq!(expected, &actual); diff --git a/datafusion/core/tests/sql/predicates.rs b/datafusion/core/tests/sql/predicates.rs index f7bdc41a93a2..3c11b690d292 100644 --- a/datafusion/core/tests/sql/predicates.rs +++ b/datafusion/core/tests/sql/predicates.rs @@ -428,7 +428,7 @@ async fn multiple_or_predicates() -> Result<()> { "Explain [plan_type:Utf8, plan:Utf8]", " Projection: #lineitem.l_partkey [l_partkey:Int64]", " Projection: #part.p_partkey = #lineitem.l_partkey AS #part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]", - " Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Int64(1) AND #lineitem.l_quantity <= Int64(11) AND #part.p_size BETWEEN Int64(1) AND Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= Int64(10) AND #lineitem.l_quantity <= Int64(20) AND #part.p_size BETWEEN Int64(1) AND Int64(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Int64(20) AND #lineitem.l_quantity <= Int64(30) AND #part.p_size BETWEEN Int64(1) AND Int64(15) [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]", + " Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= CAST(Int64(1) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(11) AS Float64) AND #part.p_size BETWEEN Int64(1) AND Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= CAST(Int64(10) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(20) AS Float64) AND #part.p_size BETWEEN Int64(1) AND Int64(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= CAST(Int64(20) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(30) AS Float64) AND #part.p_size BETWEEN Int64(1) AND Int64(15) [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]", " CrossJoin: [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]", " TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Float64]", " TableScan: part projection=[p_partkey, p_brand, p_size] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]", diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index d85a2693253a..58561de12146 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -328,7 +328,7 @@ order by s_name; Filter: #nation.n_name = Utf8("CANADA") TableScan: nation projection=[n_nationkey, n_name], partial_filters=[#nation.n_name = Utf8("CANADA")] Projection: #partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2 - Filter: #partsupp.ps_availqty > #__sq_3.__value + Filter: CAST(#partsupp.ps_availqty AS Float64) > #__sq_3.__value Inner Join: #partsupp.ps_partkey = #__sq_3.l_partkey, #partsupp.ps_suppkey = #__sq_3.l_suppkey Semi Join: #partsupp.ps_partkey = #__sq_1.p_partkey TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty] @@ -436,18 +436,16 @@ order by value desc; .create_logical_plan(sql) .map_err(|e| format!("{:?} at {}", e, "error")) .unwrap(); - println!("before:\n{}", plan.display_indent()); let plan = ctx .optimize(&plan) .map_err(|e| format!("{:?} at {}", e, "error")) .unwrap(); let actual = format!("{}", plan.display_indent()); - println!("after:\n{}", actual); let expected = r#"Sort: #value DESC NULLS FIRST Projection: #partsupp.ps_partkey, #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value Filter: #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) > #__sq_1.__value CrossJoin: - Aggregate: groupBy=[[#partsupp.ps_partkey]], aggr=[[SUM(#partsupp.ps_supplycost * #partsupp.ps_availqty)]] + Aggregate: groupBy=[[#partsupp.ps_partkey]], aggr=[[SUM(#partsupp.ps_supplycost * CAST(#partsupp.ps_availqty AS Float64))]] Inner Join: #supplier.s_nationkey = #nation.n_nationkey Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] @@ -455,7 +453,7 @@ order by value desc; Filter: #nation.n_name = Utf8("GERMANY") TableScan: nation projection=[n_nationkey, n_name], partial_filters=[#nation.n_name = Utf8("GERMANY")] Projection: #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001) AS __value, alias=__sq_1 - Aggregate: groupBy=[[]], aggr=[[SUM(#partsupp.ps_supplycost * #partsupp.ps_availqty)]] + Aggregate: groupBy=[[]], aggr=[[SUM(#partsupp.ps_supplycost * CAST(#partsupp.ps_availqty AS Float64))]] Inner Join: #supplier.s_nationkey = #nation.n_nationkey Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] diff --git a/datafusion/core/tests/sql/timestamp.rs b/datafusion/core/tests/sql/timestamp.rs index 123342c42eeb..847d63e81459 100644 --- a/datafusion/core/tests/sql/timestamp.rs +++ b/datafusion/core/tests/sql/timestamp.rs @@ -1176,11 +1176,11 @@ async fn to_timestamp_i32() -> Result<()> { let results = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+--------------------------------------+", - "| totimestamp(CAST(Int64(1) AS Int32)) |", - "+--------------------------------------+", - "| 1970-01-01 00:00:00.000000001 |", - "+--------------------------------------+", + "+-------------------------------+", + "| totimestamp(Int64(1)) |", + "+-------------------------------+", + "| 1970-01-01 00:00:00.000000001 |", + "+-------------------------------+", ]; assert_batches_eq!(expected, &results); @@ -1196,11 +1196,11 @@ async fn to_timestamp_micros_i32() -> Result<()> { let results = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+--------------------------------------------+", - "| totimestampmicros(CAST(Int64(1) AS Int32)) |", - "+--------------------------------------------+", - "| 1970-01-01 00:00:00.000001 |", - "+--------------------------------------------+", + "+-----------------------------+", + "| totimestampmicros(Int64(1)) |", + "+-----------------------------+", + "| 1970-01-01 00:00:00.000001 |", + "+-----------------------------+", ]; assert_batches_eq!(expected, &results); @@ -1216,11 +1216,11 @@ async fn to_timestamp_millis_i32() -> Result<()> { let results = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+--------------------------------------------+", - "| totimestampmillis(CAST(Int64(1) AS Int32)) |", - "+--------------------------------------------+", - "| 1970-01-01 00:00:00.001 |", - "+--------------------------------------------+", + "+-----------------------------+", + "| totimestampmillis(Int64(1)) |", + "+-----------------------------+", + "| 1970-01-01 00:00:00.001 |", + "+-----------------------------+", ]; assert_batches_eq!(expected, &results); @@ -1236,11 +1236,11 @@ async fn to_timestamp_seconds_i32() -> Result<()> { let results = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+---------------------------------------------+", - "| totimestampseconds(CAST(Int64(1) AS Int32)) |", - "+---------------------------------------------+", - "| 1970-01-01 00:00:01 |", - "+---------------------------------------------+", + "+------------------------------+", + "| totimestampseconds(Int64(1)) |", + "+------------------------------+", + "| 1970-01-01 00:00:01 |", + "+------------------------------+", ]; assert_batches_eq!(expected, &results); @@ -1512,11 +1512,11 @@ async fn cast_timestamp_before_1970() -> Result<()> { let sql = "select cast('1969-01-01T00:00:00Z' as timestamp);"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+-------------------------------------------------------------------+", - "| CAST(Utf8(\"1969-01-01T00:00:00Z\") AS Timestamp(Nanosecond, None)) |", - "+-------------------------------------------------------------------+", - "| 1969-01-01 00:00:00 |", - "+-------------------------------------------------------------------+", + "+------------------------------+", + "| Utf8(\"1969-01-01T00:00:00Z\") |", + "+------------------------------+", + "| 1969-01-01 00:00:00 |", + "+------------------------------+", ]; assert_batches_eq!(expected, &actual); @@ -1524,11 +1524,11 @@ async fn cast_timestamp_before_1970() -> Result<()> { let sql = "select cast('1969-01-01T00:00:00.1Z' as timestamp);"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+---------------------------------------------------------------------+", - "| CAST(Utf8(\"1969-01-01T00:00:00.1Z\") AS Timestamp(Nanosecond, None)) |", - "+---------------------------------------------------------------------+", - "| 1969-01-01 00:00:00.100 |", - "+---------------------------------------------------------------------+", + "+--------------------------------+", + "| Utf8(\"1969-01-01T00:00:00.1Z\") |", + "+--------------------------------+", + "| 1969-01-01 00:00:00.100 |", + "+--------------------------------+", ]; assert_batches_eq!(expected, &actual); diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 1c909fa71288..6a1f39a030d6 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -63,15 +63,15 @@ async fn csv_query_window_with_partition_by() -> Result<()> { limit 5"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+-----------+-------------------------------------------+-------------------------------------------+---------------------------------------------+-------------------------------------------+-------------------------------------------+", - "| c9 | SUM(CAST(aggregate_test_100.c4 AS Int32)) | AVG(CAST(aggregate_test_100.c4 AS Int32)) | COUNT(CAST(aggregate_test_100.c4 AS Int32)) | MAX(CAST(aggregate_test_100.c4 AS Int32)) | MIN(CAST(aggregate_test_100.c4 AS Int32)) |", - "+-----------+-------------------------------------------+-------------------------------------------+---------------------------------------------+-------------------------------------------+-------------------------------------------+", - "| 28774375 | -16110 | -16110 | 1 | -16110 | -16110 |", - "| 63044568 | 3917 | 3917 | 1 | 3917 | 3917 |", - "| 141047417 | -38455 | -19227.5 | 2 | -16974 | -21481 |", - "| 141680161 | -1114 | -1114 | 1 | -1114 | -1114 |", - "| 145294611 | 15673 | 15673 | 1 | 15673 | 15673 |", - "+-----------+-------------------------------------------+-------------------------------------------+---------------------------------------------+-------------------------------------------+-------------------------------------------+", + "+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+", + "| c9 | SUM(aggregate_test_100.c4) | AVG(aggregate_test_100.c4) | COUNT(aggregate_test_100.c4) | MAX(aggregate_test_100.c4) | MIN(aggregate_test_100.c4) |", + "+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+", + "| 28774375 | -16110 | -16110 | 1 | -16110 | -16110 |", + "| 63044568 | 3917 | 3917 | 1 | 3917 | 3917 |", + "| 141047417 | -38455 | -19227.5 | 2 | -16974 | -21481 |", + "| 141680161 | -1114 | -1114 | 1 | -1114 | -1114 |", + "| 145294611 | 15673 | 15673 | 1 | 15673 | 15673 |", + "+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+", ]; assert_batches_eq!(expected, &actual); Ok(()) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 8c6f26887081..6226887e80d0 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -923,13 +923,13 @@ fn create_name(e: &Expr) -> Result { name += "END"; Ok(name) } - Expr::Cast { expr, data_type } => { - let expr = create_name(expr)?; - Ok(format!("CAST({} AS {:?})", expr, data_type)) + Expr::Cast { expr, .. } => { + // CAST does not change the expression name + create_name(expr) } - Expr::TryCast { expr, data_type } => { - let expr = create_name(expr)?; - Ok(format!("TRY_CAST({} AS {:?})", expr, data_type)) + Expr::TryCast { expr, .. } => { + // CAST does not change the expression name + create_name(expr) } Expr::Not(expr) => { let expr = create_name(expr)?; @@ -1086,7 +1086,8 @@ fn create_names(exprs: &[Expr]) -> Result { #[cfg(test)] mod test { use crate::expr_fn::col; - use crate::{case, lit}; + use crate::{case, lit, Expr}; + use arrow::datatypes::DataType; use datafusion_common::{Result, ScalarValue}; #[test] @@ -1101,6 +1102,20 @@ mod test { Ok(()) } + #[test] + fn format_cast() -> Result<()> { + let expr = Expr::Cast { + expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)))), + data_type: DataType::Utf8, + }; + assert_eq!("CAST(Float32(1.23) AS Utf8)", format!("{}", expr)); + assert_eq!("CAST(Float32(1.23) AS Utf8)", format!("{:?}", expr)); + // note that CAST intentionally has a name that is different from its `Display` + // representation. CAST does not change the name of expressions. + assert_eq!("Float32(1.23)", expr.name()?); + Ok(()) + } + #[test] fn test_not() { assert_eq!(lit(1).not(), !lit(1)); diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 8d6da350add0..1713816599bc 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -31,6 +31,7 @@ pub mod scalar_subquery_to_join; pub mod simplify_expressions; pub mod single_distinct_to_groupby; pub mod subquery_filter_to_join; +pub mod type_coercion; pub mod utils; pub mod pre_cast_lit_in_comparison; diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs index 334ec61820df..1c826d7c39d5 100644 --- a/datafusion/optimizer/src/simplify_expressions.rs +++ b/datafusion/optimizer/src/simplify_expressions.rs @@ -1902,7 +1902,7 @@ mod tests { .build() .unwrap(); - let expected = "Projection: Int32(0) AS CAST(Utf8(\"0\") AS Int32)\ + let expected = "Projection: Int32(0) AS Utf8(\"0\")\ \n TableScan: test"; let actual = get_optimized_plan_formatted(&plan, &Utc::now()); assert_eq!(expected, actual); @@ -1949,7 +1949,7 @@ mod tests { time.timestamp_nanos() ); - assert_eq!(actual, expected); + assert_eq!(expected, actual); } #[test] @@ -1971,7 +1971,7 @@ mod tests { "Projection: NOT #test.a AS Boolean(true) OR Boolean(false) != test.a\ \n TableScan: test"; - assert_eq!(actual, expected); + assert_eq!(expected, actual); } #[test] @@ -1993,7 +1993,7 @@ mod tests { // Note that constant folder runs and folds the entire // expression down to a single constant (true) - let expected = "Filter: Boolean(true) AS CAST(now() AS Int64) < CAST(totimestamp(Utf8(\"2020-09-08T12:05:00+00:00\")) AS Int64) + Int32(50000)\ + let expected = "Filter: Boolean(true) AS now() < totimestamp(Utf8(\"2020-09-08T12:05:00+00:00\")) + Int32(50000)\ \n TableScan: test"; let actual = get_optimized_plan_formatted(&plan, &time); @@ -2025,11 +2025,11 @@ mod tests { // Note that constant folder runs and folds the entire // expression down to a single constant (true) - let expected = r#"Projection: Date32("18636") AS CAST(totimestamp(Utf8("2020-09-08T12:05:00+00:00")) AS Date32) + IntervalDayTime("528280977408") + let expected = r#"Projection: Date32("18636") AS totimestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("528280977408") TableScan: test"#; let actual = get_optimized_plan_formatted(&plan, &time); - assert_eq!(actual, expected); + assert_eq!(expected, actual); } #[test] diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs new file mode 100644 index 000000000000..d9f16159926f --- /dev/null +++ b/datafusion/optimizer/src/type_coercion.rs @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimizer rule for type validation and coercion + +use crate::{OptimizerConfig, OptimizerRule}; +use arrow::datatypes::DataType; +use datafusion_common::{DFSchema, DFSchemaRef, Result}; +use datafusion_expr::binary_rule::coerce_types; +use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; +use datafusion_expr::logical_plan::builder::build_join_schema; +use datafusion_expr::logical_plan::JoinType; +use datafusion_expr::utils::from_plan; +use datafusion_expr::ExprSchemable; +use datafusion_expr::{Expr, LogicalPlan}; + +#[derive(Default)] +pub struct TypeCoercion {} + +impl TypeCoercion { + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for TypeCoercion { + fn name(&self) -> &str { + "TypeCoercion" + } + + fn optimize( + &self, + plan: &LogicalPlan, + optimizer_config: &mut OptimizerConfig, + ) -> Result { + // optimize child plans first + let new_inputs = plan + .inputs() + .iter() + .map(|p| self.optimize(p, optimizer_config)) + .collect::>>()?; + + let schema = match new_inputs.len() { + 1 => new_inputs[0].schema().clone(), + 2 => DFSchemaRef::new(build_join_schema( + new_inputs[0].schema(), + new_inputs[1].schema(), + &JoinType::Inner, + )?), + _ => DFSchemaRef::new(DFSchema::empty()), + }; + + let mut expr_rewrite = TypeCoercionRewriter { schema }; + + let new_expr = plan + .expressions() + .into_iter() + .map(|expr| expr.rewrite(&mut expr_rewrite)) + .collect::>>()?; + + from_plan(plan, &new_expr, &new_inputs) + } +} + +struct TypeCoercionRewriter { + schema: DFSchemaRef, +} + +impl ExprRewriter for TypeCoercionRewriter { + fn pre_visit(&mut self, _expr: &Expr) -> Result { + Ok(RewriteRecursion::Continue) + } + + fn mutate(&mut self, expr: Expr) -> Result { + match &expr { + Expr::BinaryExpr { left, op, right } => { + let left_type = left.get_type(&self.schema)?; + let right_type = right.get_type(&self.schema)?; + match right_type { + DataType::Interval(_) => { + // we don't want to cast intervals because that breaks + // the logic in the physical planner + Ok(expr) + } + _ => { + let coerced_type = coerce_types(&left_type, op, &right_type)?; + let left = left.clone().cast_to(&coerced_type, &self.schema)?; + let right = right.clone().cast_to(&coerced_type, &self.schema)?; + match (&left, &right) { + (Expr::Cast { .. }, _) | (_, Expr::Cast { .. }) => { + Ok(Expr::BinaryExpr { + left: Box::new(left), + op: *op, + right: Box::new(right), + }) + } + _ => { + // no cast was added so we return the original expression + Ok(expr) + } + } + } + } + } + _ => Ok(expr), + } + } +} + +#[cfg(test)] +mod test { + use crate::type_coercion::TypeCoercion; + use crate::{OptimizerConfig, OptimizerRule}; + use datafusion_common::{DFSchema, Result}; + use datafusion_expr::logical_plan::{EmptyRelation, Projection}; + use datafusion_expr::{lit, LogicalPlan}; + use std::sync::Arc; + + #[test] + fn simple_case() -> Result<()> { + let expr = lit(1.2_f64).lt(lit(2_u32)); + let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(DFSchema::empty()), + })); + let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?); + let rule = TypeCoercion::new(); + let mut config = OptimizerConfig::default(); + let plan = rule.optimize(&plan, &mut config)?; + assert_eq!( + "Projection: Float64(1.2) < CAST(UInt32(2) AS Float64)\n EmptyRelation", + &format!("{:?}", plan) + ); + Ok(()) + } + + #[test] + fn nested_case() -> Result<()> { + let expr = lit(1.2_f64).lt(lit(2_u32)); + let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(DFSchema::empty()), + })); + let plan = LogicalPlan::Projection(Projection::try_new( + vec![expr.clone().or(expr)], + empty, + None, + )?); + let rule = TypeCoercion::new(); + let mut config = OptimizerConfig::default(); + let plan = rule.optimize(&plan, &mut config)?; + assert_eq!("Projection: Float64(1.2) < CAST(UInt32(2) AS Float64) OR Float64(1.2) < CAST(UInt32(2) AS Float64)\ + \n EmptyRelation", &format!("{:?}", plan)); + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 4226364c946b..c344982b3379 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -47,7 +47,13 @@ pub fn create_physical_expr( input_schema: &Schema, execution_props: &ExecutionProps, ) -> Result> { - assert_eq!(input_schema.fields.len(), input_dfschema.fields().len()); + if input_schema.fields.len() != input_dfschema.fields().len() { + return Err(DataFusionError::Internal( + "create_physical_expr passed Arrow schema and DataFusion \ + schema with different number of fields" + .to_string(), + )); + } match e { Expr::Alias(expr, ..) => Ok(create_physical_expr( expr,