diff --git a/datafusion/optimizer/src/limit_push_down.rs b/datafusion/optimizer/src/limit_push_down.rs
index 2f121b881432..51b2bb4f0d37 100644
--- a/datafusion/optimizer/src/limit_push_down.rs
+++ b/datafusion/optimizer/src/limit_push_down.rs
@@ -24,6 +24,7 @@ use datafusion_expr::{
         Join, JoinType, Limit, LogicalPlan, Projection, Sort, TableScan, Union,
     },
     utils::from_plan,
+    CrossJoin,
 };
 
 use std::sync::Arc;
@@ -204,6 +205,45 @@ fn limit_push_down(
                 schema: schema.clone(),
             }))
         }
+        // A cross join pairs every left row with every right row, so
+        // producing `skip + fetch` output rows needs at most `skip + fetch`
+        // rows from each input; push that bound down into both children.
+        (
+            LogicalPlan::CrossJoin(cross_join),
+            Ancestor::FromLimit {
+                skip: ancestor_skip,
+                fetch: Some(ancestor_fetch),
+                ..
+            },
+        ) => {
+            // Total rows the ancestor limit consumes. Saturating: a plain `+`
+            // overflows `usize` for very large limits (panic in debug builds,
+            // wrap-around to a too-small fetch in release builds).
+            let pushed_fetch = ancestor_fetch.saturating_add(ancestor_skip);
+            let left = &*cross_join.left;
+            let right = &*cross_join.right;
+            Ok(LogicalPlan::CrossJoin(CrossJoin {
+                left: Arc::new(limit_push_down(
+                    _optimizer,
+                    Ancestor::FromLimit {
+                        skip: 0,
+                        fetch: Some(pushed_fetch),
+                    },
+                    left,
+                    _optimizer_config,
+                )?),
+                right: Arc::new(limit_push_down(
+                    _optimizer,
+                    Ancestor::FromLimit {
+                        skip: 0,
+                        fetch: Some(pushed_fetch),
+                    },
+                    right,
+                    _optimizer_config,
+                )?),
+                schema: plan.schema().clone(),
+            }))
+        }
         (
             LogicalPlan::Join(Join { join_type, .. }),
             Ancestor::FromLimit {
@@ -394,6 +434,7 @@ mod test {
         Ok(())
     }
 
+
     #[test]
     fn limit_push_down_take_smaller_limit() -> Result<()> {
         let table_scan = test_table_scan()?;
@@ -872,4 +913,48 @@ mod test {
 
         Ok(())
     }
+
+    // A limit above a cross join stays in place, and its fetch is copied
+    // into both input scans.
+    #[test]
+    fn limit_push_down_cross_join() -> Result<()> {
+        let table_scan_1 = test_table_scan()?;
+        let table_scan_2 = test_table_scan_with_name("test2")?;
+
+        let plan = LogicalPlanBuilder::from(table_scan_1)
+            .cross_join(&LogicalPlanBuilder::from(table_scan_2).build()?)?
+            .limit(0, Some(1000))?
+            .build()?;
+
+        let expected = "Limit: skip=0, fetch=1000\
+        \n  CrossJoin:\
+        \n    TableScan: test, fetch=1000\
+        \n    TableScan: test2, fetch=1000";
+
+        assert_optimized_plan_eq(&plan, expected);
+
+        Ok(())
+    }
+
+    // With a non-zero skip, each input scan must supply `skip + fetch` rows
+    // (1000 + 1000 = 2000); the limit above the join is left unchanged.
+    #[test]
+    fn skip_limit_push_down_cross_join() -> Result<()> {
+        let table_scan_1 = test_table_scan()?;
+        let table_scan_2 = test_table_scan_with_name("test2")?;
+
+        let plan = LogicalPlanBuilder::from(table_scan_1)
+            .cross_join(&LogicalPlanBuilder::from(table_scan_2).build()?)?
+            .limit(1000, Some(1000))?
+            .build()?;
+
+        let expected = "Limit: skip=1000, fetch=1000\
+        \n  CrossJoin:\
+        \n    TableScan: test, fetch=2000\
+        \n    TableScan: test2, fetch=2000";
+
+        assert_optimized_plan_eq(&plan, expected);
+
+        Ok(())
+    }
 }