Skip to content

Commit

Permalink
Stop copying LogicalPlan and Exprs in PushDownLimit (apache#10508)
Browse files Browse the repository at this point in the history
* Stop copying LogicalPlan and Exprs in `PushDownLimit`

* Refine make_limit
  • Loading branch information
alamb authored and findepi committed Jul 16, 2024
1 parent ba42abe commit 28732ce
Showing 1 changed file with 149 additions and 126 deletions.
275 changes: 149 additions & 126 deletions datafusion/optimizer/src/push_down_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ use std::sync::Arc;
use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};

use datafusion_common::Result;
use datafusion_expr::logical_plan::{
Join, JoinType, Limit, LogicalPlan, Sort, TableScan, Union,
};
use datafusion_expr::CrossJoin;
use datafusion_common::tree_node::Transformed;
use datafusion_common::{internal_err, Result};
use datafusion_expr::logical_plan::tree_node::unwrap_arc;
use datafusion_expr::logical_plan::{Join, JoinType, Limit, LogicalPlan};

/// Optimization rule that tries to push down `LIMIT`.
///
Expand All @@ -46,131 +45,120 @@ impl PushDownLimit {
impl OptimizerRule for PushDownLimit {
fn try_optimize(
&self,
plan: &LogicalPlan,
_plan: &LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
use std::cmp::min;
internal_err!("Should have called PushDownLimit::rewrite")
}

fn supports_rewrite(&self) -> bool {
true
}

let LogicalPlan::Limit(limit) = plan else {
return Ok(None);
fn rewrite(
&self,
plan: LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
let LogicalPlan::Limit(mut limit) = plan else {
return Ok(Transformed::no(plan));
};

if let LogicalPlan::Limit(child) = &*limit.input {
// Merge the Parent Limit and the Child Limit.
let Limit { skip, fetch, input } = limit;
let input = input;

// Merge the Parent Limit and the Child Limit.
if let LogicalPlan::Limit(child) = input.as_ref() {
let (skip, fetch) =
combine_limit(limit.skip, limit.fetch, child.skip, child.fetch);

let plan = LogicalPlan::Limit(Limit {
skip,
fetch,
input: Arc::new((*child.input).clone()),
input: Arc::clone(&child.input),
});
return self
.try_optimize(&plan, _config)
.map(|opt_plan| opt_plan.or_else(|| Some(plan)));

// recursively reapply the rule on the new plan
return self.rewrite(plan, _config);
}

let Some(fetch) = limit.fetch else {
return Ok(None);
// no fetch to push, so return the original plan
let Some(fetch) = fetch else {
return Ok(Transformed::no(LogicalPlan::Limit(Limit {
skip,
fetch,
input,
})));
};
let skip = limit.skip;

match limit.input.as_ref() {
LogicalPlan::TableScan(scan) => {
let limit = if fetch != 0 { fetch + skip } else { 0 };
let new_fetch = scan.fetch.map(|x| min(x, limit)).or(Some(limit));
match unwrap_arc(input) {
LogicalPlan::TableScan(mut scan) => {
let rows_needed = if fetch != 0 { fetch + skip } else { 0 };
let new_fetch = scan
.fetch
.map(|x| min(x, rows_needed))
.or(Some(rows_needed));
if new_fetch == scan.fetch {
Ok(None)
original_limit(skip, fetch, LogicalPlan::TableScan(scan))
} else {
let new_input = LogicalPlan::TableScan(TableScan {
table_name: scan.table_name.clone(),
source: scan.source.clone(),
projection: scan.projection.clone(),
filters: scan.filters.clone(),
fetch: scan.fetch.map(|x| min(x, limit)).or(Some(limit)),
projected_schema: scan.projected_schema.clone(),
});
plan.with_new_exprs(plan.expressions(), vec![new_input])
.map(Some)
// push limit into the table scan itself
scan.fetch = scan
.fetch
.map(|x| min(x, rows_needed))
.or(Some(rows_needed));
transformed_limit(skip, fetch, LogicalPlan::TableScan(scan))
}
}
LogicalPlan::Union(union) => {
let new_inputs = union
LogicalPlan::Union(mut union) => {
// push limits to each input of the union
union.inputs = union
.inputs
.iter()
.map(|x| {
Ok(Arc::new(LogicalPlan::Limit(Limit {
skip: 0,
fetch: Some(fetch + skip),
input: x.clone(),
})))
})
.collect::<Result<_>>()?;
let union = LogicalPlan::Union(Union {
inputs: new_inputs,
schema: union.schema.clone(),
});
plan.with_new_exprs(plan.expressions(), vec![union])
.map(Some)
.into_iter()
.map(|input| make_arc_limit(0, fetch + skip, input))
.collect();
transformed_limit(skip, fetch, LogicalPlan::Union(union))
}

LogicalPlan::CrossJoin(cross_join) => {
let new_left = LogicalPlan::Limit(Limit {
skip: 0,
fetch: Some(fetch + skip),
input: cross_join.left.clone(),
});
let new_right = LogicalPlan::Limit(Limit {
skip: 0,
fetch: Some(fetch + skip),
input: cross_join.right.clone(),
});
let new_cross_join = LogicalPlan::CrossJoin(CrossJoin {
left: Arc::new(new_left),
right: Arc::new(new_right),
schema: plan.schema().clone(),
});
plan.with_new_exprs(plan.expressions(), vec![new_cross_join])
.map(Some)
LogicalPlan::CrossJoin(mut cross_join) => {
// push limit to both inputs
cross_join.left = make_arc_limit(0, fetch + skip, cross_join.left);
cross_join.right = make_arc_limit(0, fetch + skip, cross_join.right);
transformed_limit(skip, fetch, LogicalPlan::CrossJoin(cross_join))
}

LogicalPlan::Join(join) => {
if let Some(new_join) = push_down_join(join, fetch + skip) {
let inputs = vec![LogicalPlan::Join(new_join)];
plan.with_new_exprs(plan.expressions(), inputs).map(Some)
} else {
Ok(None)
}
}
LogicalPlan::Join(join) => Ok(push_down_join(join, fetch + skip)
.update_data(|join| {
make_limit(skip, fetch, Arc::new(LogicalPlan::Join(join)))
})),

LogicalPlan::Sort(sort) => {
LogicalPlan::Sort(mut sort) => {
let new_fetch = {
let sort_fetch = skip + fetch;
Some(sort.fetch.map(|f| f.min(sort_fetch)).unwrap_or(sort_fetch))
};
if new_fetch == sort.fetch {
Ok(None)
original_limit(skip, fetch, LogicalPlan::Sort(sort))
} else {
let new_sort = LogicalPlan::Sort(Sort {
expr: sort.expr.clone(),
input: sort.input.clone(),
fetch: new_fetch,
});
plan.with_new_exprs(plan.expressions(), vec![new_sort])
.map(Some)
sort.fetch = new_fetch;
limit.input = Arc::new(LogicalPlan::Sort(sort));
Ok(Transformed::yes(LogicalPlan::Limit(limit)))
}
}
child_plan @ (LogicalPlan::Projection(_) | LogicalPlan::SubqueryAlias(_)) => {
LogicalPlan::Projection(mut proj) => {
// commute
let new_limit = plan.with_new_exprs(
plan.expressions(),
vec![child_plan.inputs()[0].clone()],
)?;
child_plan
.with_new_exprs(child_plan.expressions(), vec![new_limit])
.map(Some)
limit.input = Arc::clone(&proj.input);
let new_limit = LogicalPlan::Limit(limit);
proj.input = Arc::new(new_limit);
Ok(Transformed::yes(LogicalPlan::Projection(proj)))
}
_ => Ok(None),
LogicalPlan::SubqueryAlias(mut subquery_alias) => {
// commute
limit.input = Arc::clone(&subquery_alias.input);
let new_limit = LogicalPlan::Limit(limit);
subquery_alias.input = Arc::new(new_limit);
Ok(Transformed::yes(LogicalPlan::SubqueryAlias(subquery_alias)))
}
input => original_limit(skip, fetch, input),
}
}

Expand All @@ -183,6 +171,61 @@ impl OptimizerRule for PushDownLimit {
}
}

/// Wrap the input plan with a limit node
///
/// Original:
/// ```text
/// input
/// ```
///
/// Return
/// ```text
/// Limit: skip=skip, fetch=fetch
/// input
/// ```
fn make_limit(skip: usize, fetch: usize, input: Arc<LogicalPlan>) -> LogicalPlan {
LogicalPlan::Limit(Limit {
skip,
fetch: Some(fetch),
input,
})
}

/// Wrap the input plan with a limit node
fn make_arc_limit(
skip: usize,
fetch: usize,
input: Arc<LogicalPlan>,
) -> Arc<LogicalPlan> {
Arc::new(make_limit(skip, fetch, input))
}

/// Returns the original limit (non transformed)
fn original_limit(
skip: usize,
fetch: usize,
input: LogicalPlan,
) -> Result<Transformed<LogicalPlan>> {
Ok(Transformed::no(LogicalPlan::Limit(Limit {
skip,
fetch: Some(fetch),
input: Arc::new(input),
})))
}

/// Returns the a transformed limit
fn transformed_limit(
skip: usize,
fetch: usize,
input: LogicalPlan,
) -> Result<Transformed<LogicalPlan>> {
Ok(Transformed::yes(LogicalPlan::Limit(Limit {
skip,
fetch: Some(fetch),
input: Arc::new(input),
})))
}

/// Combines two limits into a single
///
/// Returns the combined limit `(skip, fetch)`
Expand Down Expand Up @@ -255,14 +298,15 @@ fn combine_limit(
(combined_skip, combined_fetch)
}

fn push_down_join(join: &Join, limit: usize) -> Option<Join> {
/// Adds a limit to the inputs of a join, if possible
fn push_down_join(mut join: Join, limit: usize) -> Transformed<Join> {
use JoinType::*;

fn is_no_join_condition(join: &Join) -> bool {
join.on.is_empty() && join.filter.is_none()
}

let (left_limit, right_limit) = if is_no_join_condition(join) {
let (left_limit, right_limit) = if is_no_join_condition(&join) {
match join.join_type {
Left | Right | Full => (Some(limit), Some(limit)),
LeftAnti | LeftSemi => (Some(limit), None),
Expand All @@ -277,37 +321,16 @@ fn push_down_join(join: &Join, limit: usize) -> Option<Join> {
}
};

match (left_limit, right_limit) {
(None, None) => None,
_ => {
let left = match left_limit {
Some(limit) => Arc::new(LogicalPlan::Limit(Limit {
skip: 0,
fetch: Some(limit),
input: join.left.clone(),
})),
None => join.left.clone(),
};
let right = match right_limit {
Some(limit) => Arc::new(LogicalPlan::Limit(Limit {
skip: 0,
fetch: Some(limit),
input: join.right.clone(),
})),
None => join.right.clone(),
};
Some(Join {
left,
right,
on: join.on.clone(),
filter: join.filter.clone(),
join_type: join.join_type,
join_constraint: join.join_constraint,
schema: join.schema.clone(),
null_equals_null: join.null_equals_null,
})
}
if left_limit.is_none() && right_limit.is_none() {
return Transformed::no(join);
}
if let Some(limit) = left_limit {
join.left = make_arc_limit(0, limit, join.left);
}
if let Some(limit) = right_limit {
join.right = make_arc_limit(0, limit, join.right);
}
Transformed::yes(join)
}

#[cfg(test)]
Expand Down

0 comments on commit 28732ce

Please sign in to comment.