From dafd957288371597902c6af9df3f31a31ac418a2 Mon Sep 17 00:00:00 2001 From: jakevin Date: Mon, 28 Nov 2022 23:02:13 +0800 Subject: [PATCH] add `with_new_inputs` (#4393) --- datafusion/expr/src/logical_plan/plan.rs | 9 ++++++++- datafusion/optimizer/src/limit_push_down.rs | 15 +++++++-------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 82e44986a7b1..ed5711a7fb57 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -21,7 +21,7 @@ use crate::logical_plan::builder::validate_unique_names; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::utils::{ - exprlist_to_fields, grouping_set_expr_count, grouping_set_to_exprlist, + exprlist_to_fields, from_plan, grouping_set_expr_count, grouping_set_to_exprlist, }; use crate::{Expr, ExprSchemable, TableProviderFilterPushDown, TableSource}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -349,6 +349,13 @@ impl LogicalPlan { self.accept(&mut visitor)?; Ok(visitor.using_columns) } + + pub fn with_new_inputs( + &self, + inputs: &[LogicalPlan], + ) -> Result { + from_plan(self, &self.expressions(), inputs) + } } /// Trait that implements the [Visitor diff --git a/datafusion/optimizer/src/limit_push_down.rs b/datafusion/optimizer/src/limit_push_down.rs index 9dbda4653763..28b868fd643f 100644 --- a/datafusion/optimizer/src/limit_push_down.rs +++ b/datafusion/optimizer/src/limit_push_down.rs @@ -19,7 +19,6 @@ //! It will push down through projection, limits (taking the smaller limit) use crate::{utils, OptimizerConfig, OptimizerRule}; use datafusion_common::Result; -use datafusion_expr::utils::from_plan; use datafusion_expr::{ logical_plan::{ Join, JoinType, Limit, LogicalPlan, Projection, Sort, TableScan, Union, @@ -131,7 +130,7 @@ impl OptimizerRule for LimitPushDown { fetch: scan.fetch.map(|x| std::cmp::min(x, limit)).or(Some(limit)), projected_schema: scan.projected_schema.clone(), }); - from_plan(plan, &plan.expressions(), &[new_input])? + plan.with_new_inputs(&[new_input])? } LogicalPlan::Projection(projection) => { @@ -164,7 +163,7 @@ impl OptimizerRule for LimitPushDown { inputs: new_inputs, schema: union.schema.clone(), }); - from_plan(plan, &plan.expressions(), &[union])? + plan.with_new_inputs(&[union])? } LogicalPlan::CrossJoin(cross_join) => { @@ -180,12 +179,12 @@ impl OptimizerRule for LimitPushDown { fetch: Some(fetch + skip), input: Arc::new(right.clone()), }); - let new_input = LogicalPlan::CrossJoin(CrossJoin { + let new_cross_join = LogicalPlan::CrossJoin(CrossJoin { left: Arc::new(new_left), right: Arc::new(new_right), schema: plan.schema().clone(), }); - from_plan(plan, &plan.expressions(), &[new_input])? + plan.with_new_inputs(&[new_cross_join])? } LogicalPlan::Join(join) => { @@ -195,19 +194,19 @@ impl OptimizerRule for LimitPushDown { JoinType::Right => push_down_join(join, None, Some(limit)), _ => push_down_join(join, None, None), }; - from_plan(plan, &plan.expressions(), &[new_join])? + plan.with_new_inputs(&[new_join])? } LogicalPlan::Sort(sort) => { let sort_fetch = skip + fetch; - let new_input = LogicalPlan::Sort(Sort { + let new_sort = LogicalPlan::Sort(Sort { expr: sort.expr.clone(), input: Arc::new((*sort.input).clone()), fetch: Some( sort.fetch.map(|f| f.min(sort_fetch)).unwrap_or(sort_fetch), ), }); - from_plan(plan, &plan.expressions(), &[new_input])? + plan.with_new_inputs(&[new_sort])? } _ => plan.clone(), };