From 8fa80e70d3cfe22335f3e5a9c4a5b91aab6d14ea Mon Sep 17 00:00:00 2001
From: Peter Toth
Date: Tue, 19 Dec 2023 16:40:34 +0100
Subject: [PATCH] - refactor `EnforceDistribution` using
 `transform_down_with_payload()`

---
 .../enforce_distribution.rs | 251 ++++++------------
 1 file changed, 88 insertions(+), 163 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs
index b54ec2d6a7f0..1d345bbfaeaf 100644
--- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs
+++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs
@@ -191,17 +191,15 @@ impl EnforceDistribution {
 impl PhysicalOptimizerRule for EnforceDistribution {
     fn optimize(
         &self,
-        plan: Arc<dyn ExecutionPlan>,
+        mut plan: Arc<dyn ExecutionPlan>,
         config: &ConfigOptions,
     ) -> Result<Arc<dyn ExecutionPlan>> {
         let top_down_join_key_reordering = config.optimizer.top_down_join_key_reordering;
 
         let adjusted = if top_down_join_key_reordering {
             // Run a top-down process to adjust input key ordering recursively
-            let plan_requirements = PlanWithKeyRequirements::new(plan);
-            let adjusted =
-                plan_requirements.transform_down_old(&adjust_input_keys_ordering)?;
-            adjusted.plan
+            plan.transform_down_with_payload(&mut adjust_input_keys_ordering, None)?;
+            plan
         } else {
             // Run a bottom-up process
             plan.transform_up_old(&|plan| {
@@ -269,12 +267,15 @@ impl PhysicalOptimizerRule for EnforceDistribution {
 /// 4) If the current plan is Projection, transform the requirements to the columns before the Projection and push down requirements
 /// 5) For other types of operators, by default, pushdown the parent requirements to children.
 ///
+type RequiredKeyOrdering = Option<Vec<Arc<dyn PhysicalExpr>>>;
+
 fn adjust_input_keys_ordering(
-    requirements: PlanWithKeyRequirements,
-) -> Result<Transformed<PlanWithKeyRequirements>> {
-    let parent_required = requirements.required_key_ordering.clone();
-    let plan_any = requirements.plan.as_any();
-    let transformed = if let Some(HashJoinExec {
+    plan: &mut Arc<dyn ExecutionPlan>,
+    required_key_ordering: RequiredKeyOrdering,
+) -> Result<(TreeNodeRecursion, Vec<RequiredKeyOrdering>)> {
+    let parent_required = required_key_ordering.unwrap_or_default().clone();
+    let plan_any = plan.as_any();
+    if let Some(HashJoinExec {
         left,
         right,
         on,
@@ -299,13 +300,15 @@ fn adjust_input_keys_ordering(
                         *null_equals_null,
                     )?) as Arc<dyn ExecutionPlan>)
                 };
-                Some(reorder_partitioned_join_keys(
-                    requirements.plan.clone(),
+                let (new_plan, request_key_ordering) = reorder_partitioned_join_keys(
+                    plan.clone(),
                     &parent_required,
                     on,
                     vec![],
                     &join_constructor,
-                )?)
+                )?;
+                *plan = new_plan;
+                Ok((TreeNodeRecursion::Continue, request_key_ordering))
             }
             PartitionMode::CollectLeft => {
                 let new_right_request = match join_type {
@@ -323,15 +326,14 @@
                 };
 
                 // Push down requirements to the right side
-                Some(PlanWithKeyRequirements {
-                    plan: requirements.plan.clone(),
-                    required_key_ordering: vec![],
-                    request_key_ordering: vec![None, new_right_request],
-                })
+                Ok((TreeNodeRecursion::Continue, vec![None, new_right_request]))
             }
             PartitionMode::Auto => {
                 // Can not satisfy, clear the current requirements and generate new empty requirements
-                Some(PlanWithKeyRequirements::new(requirements.plan.clone()))
+                Ok((
+                    TreeNodeRecursion::Continue,
+                    vec![None; plan.children().len()],
+                ))
             }
         }
     } else if let Some(CrossJoinExec { left, .. }) =
@@ -339,14 +341,13 @@ fn adjust_input_keys_ordering(
     {
         let left_columns_len = left.schema().fields().len();
         // Push down requirements to the right side
-        Some(PlanWithKeyRequirements {
-            plan: requirements.plan.clone(),
-            required_key_ordering: vec![],
-            request_key_ordering: vec![
+        Ok((
+            TreeNodeRecursion::Continue,
+            vec![
                 None,
                 shift_right_required(&parent_required, left_columns_len),
             ],
-        })
+        ))
     } else if let Some(SortMergeJoinExec {
         left,
         right,
@@ -368,26 +369,38 @@
                 *null_equals_null,
             )?) as Arc<dyn ExecutionPlan>)
         };
-        Some(reorder_partitioned_join_keys(
-            requirements.plan.clone(),
+        let (new_plan, request_key_ordering) = reorder_partitioned_join_keys(
+            plan.clone(),
             &parent_required,
             on,
             sort_options.clone(),
             &join_constructor,
-        )?)
+        )?;
+        *plan = new_plan;
+        Ok((TreeNodeRecursion::Continue, request_key_ordering))
     } else if let Some(aggregate_exec) = plan_any.downcast_ref::<AggregateExec>() {
         if !parent_required.is_empty() {
             match aggregate_exec.mode() {
-                AggregateMode::FinalPartitioned => Some(reorder_aggregate_keys(
-                    requirements.plan.clone(),
-                    &parent_required,
-                    aggregate_exec,
-                )?),
-                _ => Some(PlanWithKeyRequirements::new(requirements.plan.clone())),
+                AggregateMode::FinalPartitioned => {
+                    let (new_plan, request_key_ordering) = reorder_aggregate_keys(
+                        plan.clone(),
+                        &parent_required,
+                        aggregate_exec,
+                    )?;
+                    *plan = new_plan;
+                    Ok((TreeNodeRecursion::Continue, request_key_ordering))
+                }
+                _ => Ok((
+                    TreeNodeRecursion::Continue,
+                    vec![None; plan.children().len()],
+                )),
             }
         } else {
             // Keep everything unchanged
-            None
+            Ok((
+                TreeNodeRecursion::Continue,
+                vec![None; plan.children().len()],
+            ))
        }
    } else if let Some(proj) = plan_any.downcast_ref::<ProjectionExec>() {
        let expr = proj.expr();
@@ -396,34 +409,33 @@
         // Construct a mapping from new name to the the orginal Column
         let new_required = map_columns_before_projection(&parent_required, expr);
         if new_required.len() == parent_required.len() {
-            Some(PlanWithKeyRequirements {
-                plan: requirements.plan.clone(),
-                required_key_ordering: vec![],
-                request_key_ordering: vec![Some(new_required.clone())],
-            })
+            Ok((
+                TreeNodeRecursion::Continue,
+                vec![Some(new_required.clone())],
+            ))
         } else {
             // Can not satisfy, clear the current requirements and generate new empty requirements
-            Some(PlanWithKeyRequirements::new(requirements.plan.clone()))
+            Ok((
+                TreeNodeRecursion::Continue,
+                vec![None; plan.children().len()],
+            ))
         }
     } else if plan_any.downcast_ref::<RepartitionExec>().is_some()
         || plan_any.downcast_ref::<CoalescePartitionsExec>().is_some()
         || plan_any.downcast_ref::<WindowAggExec>().is_some()
     {
-        Some(PlanWithKeyRequirements::new(requirements.plan.clone()))
+        Ok((
+            TreeNodeRecursion::Continue,
+            vec![None; plan.children().len()],
+        ))
     } else {
         // By default, push down the parent requirements to children
-        let children_len = requirements.plan.children().len();
-        Some(PlanWithKeyRequirements {
-            plan: requirements.plan.clone(),
-            required_key_ordering: vec![],
-            request_key_ordering: vec![Some(parent_required.clone()); children_len],
-        })
-    };
-    Ok(if let Some(transformed) = transformed {
-        Transformed::Yes(transformed)
-    } else {
-        Transformed::No(requirements)
-    })
+        let children_len = plan.children().len();
+        Ok((
+            TreeNodeRecursion::Continue,
+            vec![Some(parent_required.clone()); children_len],
+        ))
+    }
 }
 
 fn reorder_partitioned_join_keys<F>(
@@ -432,7 +444,7 @@
     on: &[(Column, Column)],
     sort_options: Vec<SortOptions>,
     join_constructor: &F,
-) -> Result<PlanWithKeyRequirements>
+) -> Result<(Arc<dyn ExecutionPlan>, Vec<RequiredKeyOrdering>)>
 where
     F: Fn((Vec<(Column, Column)>, Vec<SortOptions>)) -> Result<Arc<dyn ExecutionPlan>>,
 {
@@ -455,27 +467,21 @@ where
                 new_sort_options.push(sort_options[new_positions[idx]])
             }
 
-            Ok(PlanWithKeyRequirements {
-                plan: join_constructor((new_join_on, new_sort_options))?,
-                required_key_ordering: vec![],
-                request_key_ordering: vec![Some(left_keys), Some(right_keys)],
-            })
+            Ok((
+                join_constructor((new_join_on, new_sort_options))?,
+                vec![Some(left_keys), Some(right_keys)],
+            ))
         } else {
-            Ok(PlanWithKeyRequirements {
-                plan: join_plan,
-                required_key_ordering: vec![],
-                request_key_ordering: vec![Some(left_keys), Some(right_keys)],
-            })
+            Ok((join_plan, vec![Some(left_keys), Some(right_keys)]))
         }
     } else {
-        Ok(PlanWithKeyRequirements {
-            plan: join_plan,
-            required_key_ordering: vec![],
-            request_key_ordering: vec![
+        Ok((
+            join_plan,
+            vec![
                 Some(join_key_pairs.left_keys),
                 Some(join_key_pairs.right_keys),
             ],
-        })
+        ))
     }
 }
 
@@ -483,7 +489,7 @@ fn reorder_aggregate_keys(
     agg_plan: Arc<dyn ExecutionPlan>,
     parent_required: &[Arc<dyn PhysicalExpr>],
     agg_exec: &AggregateExec,
-) -> Result<PlanWithKeyRequirements> {
+) -> Result<(Arc<dyn ExecutionPlan>, Vec<RequiredKeyOrdering>)> {
     let output_columns = agg_exec
         .group_by()
         .expr()
@@ -501,11 +507,15 @@
         || !agg_exec.group_by().null_expr().is_empty()
         || physical_exprs_equal(&output_exprs, parent_required)
     {
-        Ok(PlanWithKeyRequirements::new(agg_plan))
+        let request_key_ordering = vec![None; agg_plan.children().len()];
+        Ok((agg_plan, request_key_ordering))
     } else {
         let new_positions = expected_expr_positions(&output_exprs, parent_required);
         match new_positions {
-            None => Ok(PlanWithKeyRequirements::new(agg_plan)),
+            None => {
+                let request_key_ordering = vec![None; agg_plan.children().len()];
+                Ok((agg_plan, request_key_ordering))
+            }
             Some(positions) => {
                 let new_partial_agg = if let Some(agg_exec) =
                     agg_exec.input().as_any().downcast_ref::<AggregateExec>()
@@ -577,11 +587,13 @@
                         .push((Arc::new(Column::new(name, idx)) as _, name.clone()))
                 }
                 // TODO merge adjacent Projections if there are
-                Ok(PlanWithKeyRequirements::new(Arc::new(
-                    ProjectionExec::try_new(proj_exprs, new_final_agg)?,
-                )))
+                let new_plan =
+                    Arc::new(ProjectionExec::try_new(proj_exprs, new_final_agg)?);
+                let request_key_ordering = vec![None; new_plan.children().len()];
+                Ok((new_plan, request_key_ordering))
             } else {
-                Ok(PlanWithKeyRequirements::new(agg_plan))
+                let request_key_ordering = vec![None; agg_plan.children().len()];
+                Ok((agg_plan, request_key_ordering))
             }
         }
     }
@@ -1539,93 +1551,6 @@ struct JoinKeyPairs {
     right_keys: Vec<Arc<dyn PhysicalExpr>>,
 }
 
-#[derive(Debug, Clone)]
-struct PlanWithKeyRequirements {
-    plan: Arc<dyn ExecutionPlan>,
-    /// Parent required key ordering
-    required_key_ordering: Vec<Arc<dyn PhysicalExpr>>,
-    /// The request key ordering to children
-    request_key_ordering: Vec<Option<Vec<Arc<dyn PhysicalExpr>>>>,
-}
-
-impl PlanWithKeyRequirements {
-    fn new(plan: Arc<dyn ExecutionPlan>) -> Self {
-        let children_len = plan.children().len();
-        PlanWithKeyRequirements {
-            plan,
-            required_key_ordering: vec![],
-            request_key_ordering: vec![None; children_len],
-        }
-    }
-
-    fn children(&self) -> Vec<PlanWithKeyRequirements> {
-        let plan_children = self.plan.children();
-        assert_eq!(plan_children.len(), self.request_key_ordering.len());
-        plan_children
-            .into_iter()
-            .zip(self.request_key_ordering.clone())
-            .map(|(child, required)| {
-                let from_parent = required.unwrap_or_default();
-                let length = child.children().len();
-                PlanWithKeyRequirements {
-                    plan: child,
-                    required_key_ordering: from_parent,
-                    request_key_ordering: vec![None; length],
-                }
-            })
-            .collect()
-    }
-}
-
-impl TreeNode for PlanWithKeyRequirements {
-    fn apply_children<F>(&self, f: &mut F) -> Result<TreeNodeRecursion>
-    where
-        F: FnMut(&Self) -> Result<TreeNodeRecursion>,
-    {
-        self.children().iter().for_each_till_continue(f)
-    }
-
-    fn map_children<F>(self, transform: F) -> Result<Self>
-    where
-        F: FnMut(Self) -> Result<Self>,
-    {
-        let children = self.children();
-        if !children.is_empty() {
-            let new_children: Result<Vec<_>> =
-                children.into_iter().map(transform).collect();
-
-            let children_plans = new_children?
-                .into_iter()
-                .map(|child| child.plan)
-                .collect::<Vec<_>>();
-            let new_plan = with_new_children_if_necessary(self.plan, children_plans)?;
-            Ok(PlanWithKeyRequirements {
-                plan: new_plan.into(),
-                required_key_ordering: self.required_key_ordering,
-                request_key_ordering: self.request_key_ordering,
-            })
-        } else {
-            Ok(self)
-        }
-    }
-
-    fn transform_children<F>(&mut self, f: &mut F) -> Result<TreeNodeRecursion>
-    where
-        F: FnMut(&mut Self) -> Result<TreeNodeRecursion>,
-    {
-        let mut children = self.children();
-        if !children.is_empty() {
-            let tnr = children.iter_mut().for_each_till_continue(f)?;
-            let children_plans = children.into_iter().map(|c| c.plan).collect();
-            self.plan =
-                with_new_children_if_necessary(self.plan.clone(), children_plans)?.into();
-            Ok(tnr)
-        } else {
-            Ok(TreeNodeRecursion::Continue)
-        }
-    }
-}
-
 /// Since almost all of these tests explicitly use `ParquetExec` they only run with the parquet feature flag on
 #[cfg(feature = "parquet")]
 #[cfg(test)]
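
Note: `transform_down_with_payload()` itself is not defined in this patch; it is assumed from the companion TreeNode API refactor. Below is a minimal, self-contained sketch of the traversal contract the patch relies on, using toy `Node` and `Recursion` types in place of `Arc<dyn ExecutionPlan>` and `TreeNodeRecursion` (all names in the sketch are illustrative, not the DataFusion API):

// Illustrative sketch only: a toy re-implementation of the pre-order,
// payload-threading traversal that this patch assumes. In the patch the
// payload type is `RequiredKeyOrdering` and the closure is
// `adjust_input_keys_ordering`.
#[allow(dead_code)]
enum Recursion {
    Continue,
    Stop,
}

struct Node {
    name: String,
    children: Vec<Node>,
}

impl Node {
    /// Pre-order traversal: `f` may rewrite the node in place and must
    /// return one payload per child, which is threaded down the tree.
    fn transform_down_with_payload<P, F>(&mut self, f: &mut F, payload: P) -> Result<(), String>
    where
        F: FnMut(&mut Node, P) -> Result<(Recursion, Vec<P>), String>,
    {
        let (recursion, child_payloads) = f(self, payload)?;
        if matches!(recursion, Recursion::Stop) {
            return Ok(());
        }
        // The closure must emit exactly one payload slot per child,
        // mirroring the `vec![None; plan.children().len()]` pattern above.
        assert_eq!(child_payloads.len(), self.children.len());
        for (child, p) in self.children.iter_mut().zip(child_payloads) {
            child.transform_down_with_payload(f, p)?;
        }
        Ok(())
    }
}

fn main() -> Result<(), String> {
    let mut plan = Node {
        name: "join".into(),
        children: vec![
            Node { name: "scan_a".into(), children: vec![] },
            Node { name: "scan_b".into(), children: vec![] },
        ],
    };
    // The `Option<String>` payload plays the role of `RequiredKeyOrdering`:
    // each node receives its parent's requirement and emits one per child.
    plan.transform_down_with_payload(
        &mut |node, required: Option<String>| {
            println!("{}: required = {:?}", node.name, required);
            let down = vec![Some(format!("keys_from_{}", node.name)); node.children.len()];
            Ok((Recursion::Continue, down))
        },
        None,
    )
}

Compared with the deleted `PlanWithKeyRequirements` wrapper, which rebuilt a parallel tree just to carry `required_key_ordering`, the payload variant threads the requirement through the existing plan nodes and lets the closure rewrite `plan` in place; that is where the 163-line deletion in this patch comes from.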