From 9fb97f3be842daef5600b52f31508436739918a3 Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Wed, 16 Nov 2022 10:23:16 -0800 Subject: [PATCH 01/15] Enable inexact filters for predicate pushdown, add helper to get fitlers from TableScan struct --- dask_planner/src/sql/logical/table_scan.rs | 23 +++++++++++++++++----- dask_planner/src/sql/table.rs | 2 +- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/dask_planner/src/sql/logical/table_scan.rs b/dask_planner/src/sql/logical/table_scan.rs index db0fbf599..8f7b3026e 100644 --- a/dask_planner/src/sql/logical/table_scan.rs +++ b/dask_planner/src/sql/logical/table_scan.rs @@ -1,7 +1,12 @@ -use datafusion_expr::logical_plan::TableScan; +use std::sync::Arc; + +use datafusion_expr::{logical_plan::TableScan, LogicalPlan}; use pyo3::prelude::*; -use crate::sql::{exceptions::py_type_err, logical}; +use crate::{ + expression::{py_expr_list, PyExpr}, + sql::exceptions::py_type_err, +}; #[pyclass(name = "TableScan", module = "dask_planner", subclass)] #[derive(Clone)] @@ -31,14 +36,22 @@ impl PyTableScan { fn contains_projections(&self) -> bool { self.table_scan.projection.is_some() } + + #[pyo3(name = "getFilters")] + fn scan_filters(&self) -> PyResult> { + py_expr_list( + &Arc::new(LogicalPlan::TableScan(self.table_scan.clone())), + &self.table_scan.filters, + ) + } } -impl TryFrom for PyTableScan { +impl TryFrom for PyTableScan { type Error = PyErr; - fn try_from(logical_plan: logical::LogicalPlan) -> Result { + fn try_from(logical_plan: LogicalPlan) -> Result { match logical_plan { - logical::LogicalPlan::TableScan(table_scan) => Ok(PyTableScan { table_scan }), + LogicalPlan::TableScan(table_scan) => Ok(PyTableScan { table_scan }), _ => Err(py_type_err("unexpected plan")), } } diff --git a/dask_planner/src/sql/table.rs b/dask_planner/src/sql/table.rs index c44eec38c..721021dd9 100644 --- a/dask_planner/src/sql/table.rs +++ b/dask_planner/src/sql/table.rs @@ -62,7 +62,7 @@ impl TableSource for DaskTableSource { // to retain the Filter operator in the plan as well Ok(TableProviderFilterPushDown::Inexact) } else { - Ok(TableProviderFilterPushDown::Unsupported) + Ok(TableProviderFilterPushDown::Inexact) } } } From 06833362f6ed9a5aed4641d24da86ff77b511002 Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Wed, 16 Nov 2022 10:24:03 -0800 Subject: [PATCH 02/15] Update table scan logic to add filters --- dask_sql/physical/rel/logical/table_scan.py | 41 +++++++++++++++++---- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/dask_sql/physical/rel/logical/table_scan.py b/dask_sql/physical/rel/logical/table_scan.py index 716e51dcd..13d6fe390 100644 --- a/dask_sql/physical/rel/logical/table_scan.py +++ b/dask_sql/physical/rel/logical/table_scan.py @@ -3,6 +3,8 @@ from dask_sql.datacontainer import DataContainer from dask_sql.physical.rel.base import BaseRelPlugin +from dask_sql.physical.rel.logical.filter import filter_or_scalar +from dask_sql.physical.rex import RexConverter if TYPE_CHECKING: import dask_sql @@ -13,14 +15,13 @@ class DaskTableScanPlugin(BaseRelPlugin): """ - A DaskTableScal is the main ingredient: it will get the data + A DaskTableScan is the main ingredient: it will get the data from the database. It is always used, when the SQL looks like SELECT .... FROM table .... We need to get the dask dataframe from the registered tables and return the requested columns from it. - Calcite will always refer to columns via index. 
""" class_name = "TableScan" @@ -41,12 +42,22 @@ def convert( schema_name, table_name = [n.lower() for n in context.fqn(dask_table)] dc = context.schema[schema_name].tables[table_name] - df = dc.df + + dc = self._apply_projections(table_scan, dask_table, dc) + dc = self._apply_filters(table_scan, rel, dc, context) + cc = dc.column_container + cc = self.fix_column_to_row_type(cc, rel.getRowType()) + dc = DataContainer(dc.df, cc) + dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) + return dc + def _apply_projections(self, table_scan, dask_table, dc): # If the 'TableScan' instance contains projected columns only retrieve those columns # otherwise get all projected columns from the 'Projection' instance, which is contained # in the 'RelDataType' instance, aka 'row_type' + df = dc.df + cc = dc.column_container if table_scan.containsProjections(): field_specifications = ( table_scan.getTableScanProjects() @@ -56,9 +67,23 @@ def convert( field_specifications = [ str(f) for f in dask_table.getRowType().getFieldNames() ] - cc = cc.limit_to(field_specifications) - cc = self.fix_column_to_row_type(cc, rel.getRowType()) - dc = DataContainer(df, cc) - dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) - return dc + return DataContainer(df, cc) + + def _apply_filters(self, table_scan, rel, dc, context): + df = dc.df + cc = dc.column_container + filters = table_scan.getFilters() + # All partial filters here are applied in conjunction (&) + df_condition = None + for filter in filters: + filter_condition = RexConverter.convert(rel, filter, dc, context=context) + df_condition = ( + filter_condition + if df_condition is None + else (df_condition & filter_condition) + ) + if len(filters) > 0: + df = filter_or_scalar(df, df_condition) + + return DataContainer(df, cc) From 9a9fa1ed7a78162b3dcc019cbe3bc6cef59b93ff Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Fri, 18 Nov 2022 11:38:58 -0800 Subject: [PATCH 03/15] Update PyTableScan to include input schema --- dask_planner/src/sql/logical/table_scan.rs | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/dask_planner/src/sql/logical/table_scan.rs b/dask_planner/src/sql/logical/table_scan.rs index 8f7b3026e..537f011cc 100644 --- a/dask_planner/src/sql/logical/table_scan.rs +++ b/dask_planner/src/sql/logical/table_scan.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use datafusion_common::DFSchema; use datafusion_expr::{logical_plan::TableScan, LogicalPlan}; use pyo3::prelude::*; @@ -12,6 +13,7 @@ use crate::{ #[derive(Clone)] pub struct PyTableScan { pub(crate) table_scan: TableScan, + input: Arc, } #[pymethods] @@ -39,10 +41,7 @@ impl PyTableScan { #[pyo3(name = "getFilters")] fn scan_filters(&self) -> PyResult> { - py_expr_list( - &Arc::new(LogicalPlan::TableScan(self.table_scan.clone())), - &self.table_scan.filters, - ) + py_expr_list(&self.input, &self.table_scan.filters) } } @@ -51,7 +50,20 @@ impl TryFrom for PyTableScan { fn try_from(logical_plan: LogicalPlan) -> Result { match logical_plan { - LogicalPlan::TableScan(table_scan) => Ok(PyTableScan { table_scan }), + LogicalPlan::TableScan(table_scan) => { + // Create an input logical plan that's identical to the table scan with schema from the table source + let mut input = table_scan.clone(); + input.projected_schema = DFSchema::try_from_qualified_schema( + &table_scan.table_name, + &table_scan.source.schema(), + ) + .map_or(input.projected_schema, Arc::new); + + Ok(PyTableScan { + table_scan, + input: Arc::new(LogicalPlan::TableScan(input)), + }) + } _ => 
Err(py_type_err("unexpected plan")), } } From 8d1fd4a0b82eafc399a9cb54c69352dcd8ec23db Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Fri, 18 Nov 2022 11:42:56 -0800 Subject: [PATCH 04/15] Update DaskTableSource to allow filtering on all expr's --- dask_planner/src/sql/table.rs | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/dask_planner/src/sql/table.rs b/dask_planner/src/sql/table.rs index 721021dd9..76c24af57 100644 --- a/dask_planner/src/sql/table.rs +++ b/dask_planner/src/sql/table.rs @@ -56,23 +56,20 @@ impl TableSource for DaskTableSource { // TODO this should return Exact but we cannot make that change until we // are actually pushing the TableScan filters down to the reader because // returning Exact here would remove the Filter from the plan - Ok(TableProviderFilterPushDown::Inexact) + Ok(TableProviderFilterPushDown::Exact) } else if filters.iter().any(|f| is_supported_push_down_expr(f)) { // we can partially apply the filter in the TableScan but we need // to retain the Filter operator in the plan as well Ok(TableProviderFilterPushDown::Inexact) } else { - Ok(TableProviderFilterPushDown::Inexact) + Ok(TableProviderFilterPushDown::Unsupported) } } } -fn is_supported_push_down_expr(expr: &Expr) -> bool { - match expr { - // for now, we just attempt to push down simple IS NOT NULL filters on columns - Expr::IsNotNull(ref a) => matches!(a.as_ref(), Expr::Column(_)), - _ => false, - } +fn is_supported_push_down_expr(_expr: &Expr) -> bool { + // For now we support all kinds of expr's at this level + true } #[pyclass(name = "DaskStatistics", module = "dask_planner", subclass)] From 6b0b5447f0b135182270bc82049229115a8a59e3 Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Fri, 18 Nov 2022 11:43:58 -0800 Subject: [PATCH 05/15] Change order to apply filters before projections --- dask_sql/physical/rel/logical/table_scan.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dask_sql/physical/rel/logical/table_scan.py b/dask_sql/physical/rel/logical/table_scan.py index 13d6fe390..7a3a8c870 100644 --- a/dask_sql/physical/rel/logical/table_scan.py +++ b/dask_sql/physical/rel/logical/table_scan.py @@ -43,8 +43,9 @@ def convert( dc = context.schema[schema_name].tables[table_name] - dc = self._apply_projections(table_scan, dask_table, dc) + # Apply filter before projections since filter columns may not be in projections dc = self._apply_filters(table_scan, rel, dc, context) + dc = self._apply_projections(table_scan, dask_table, dc) cc = dc.column_container cc = self.fix_column_to_row_type(cc, rel.getRowType()) From 5a4b291ffb07cc7b027e3c52c53886bd2557d44e Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Mon, 28 Nov 2022 06:40:52 -0800 Subject: [PATCH 06/15] Clean up filter conjuction application --- .../src/sql/optimizer/filter_push_down.rs | 641 ------------------ dask_sql/physical/rel/logical/table_scan.py | 17 +- 2 files changed, 10 insertions(+), 648 deletions(-) delete mode 100644 dask_planner/src/sql/optimizer/filter_push_down.rs diff --git a/dask_planner/src/sql/optimizer/filter_push_down.rs b/dask_planner/src/sql/optimizer/filter_push_down.rs deleted file mode 100644 index ac0429774..000000000 --- a/dask_planner/src/sql/optimizer/filter_push_down.rs +++ /dev/null @@ -1,641 +0,0 @@ -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan - -use std::{ - collections::{HashMap, HashSet}, - iter::once, -}; - -use datafusion_common::{Column, DFSchema, DataFusionError, Result}; -use datafusion_expr::{ - col, - expr_rewriter::{replace_col, ExprRewritable, ExprRewriter}, - logical_plan::{ - Aggregate, - CrossJoin, - Join, - JoinType, - Limit, - LogicalPlan, - Projection, - TableScan, - Union, - }, - utils::{expr_to_columns, exprlist_to_columns, from_plan}, - Expr, - TableProviderFilterPushDown, -}; -use datafusion_optimizer::{utils, OptimizerConfig, OptimizerRule}; - -/// Filter Push Down optimizer rule pushes filter clauses down the plan -/// # Introduction -/// A filter-commutative operation is an operation whose result of filter(op(data)) = op(filter(data)). -/// An example of a filter-commutative operation is a projection; a counter-example is `limit`. -/// -/// The filter-commutative property is column-specific. An aggregate grouped by A on SUM(B) -/// can commute with a filter that depends on A only, but does not commute with a filter that depends -/// on SUM(B). -/// -/// This optimizer commutes filters with filter-commutative operations to push the filters -/// the closest possible to the scans, re-writing the filter expressions by every -/// projection that changes the filter's expression. -/// -/// Filter: b Gt Int64(10) -/// Projection: a AS b -/// -/// is optimized to -/// -/// Projection: a AS b -/// Filter: a Gt Int64(10) <--- changed from b to a -/// -/// This performs a single pass through the plan. When it passes through a filter, it stores that filter, -/// and when it reaches a node that does not commute with it, it adds the filter to that place. -/// When it passes through a projection, it re-writes the filter's expression taking into account that projection. -/// When multiple filters would have been written, it `AND` their expressions into a single expression. -#[derive(Default)] -pub struct FilterPushDown {} - -/// Filter predicate represented by tuple of expression and its columns -type Predicate = (Expr, HashSet); - -/// Multiple filter predicates represented by tuple of expressions vector -/// and corresponding expression columns vector -type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet>); - -#[derive(Debug, Clone, Default)] -struct State { - // (predicate, columns on the predicate) - filters: Vec, -} - -impl State { - fn append_predicates(&mut self, predicates: Predicates) { - predicates - .0 - .into_iter() - .zip(predicates.1) - .for_each(|(expr, cols)| self.filters.push((expr.clone(), cols.clone()))) - } -} - -/// returns all predicates in `state` that depend on any of `used_columns` -/// or the ones that does not reference any columns (e.g. 
WHERE 1=1) -fn get_predicates<'a>(state: &'a State, used_columns: &HashSet) -> Predicates<'a> { - state - .filters - .iter() - .filter(|(_, columns)| { - columns.is_empty() - || !columns - .intersection(used_columns) - .collect::>() - .is_empty() - }) - .map(|&(ref a, ref b)| (a, b)) - .unzip() -} - -/// Optimizes the plan -fn push_down(state: &State, plan: &LogicalPlan) -> Result { - let new_inputs = plan - .inputs() - .iter() - .map(|input| optimize(input, state.clone())) - .collect::>>()?; - - let expr = plan.expressions(); - from_plan(plan, &expr, &new_inputs) -} - -// remove all filters from `filters` that are in `predicate_columns` -fn remove_filters(filters: &[Predicate], predicate_columns: &[&HashSet]) -> Vec { - filters - .iter() - .filter(|(_, columns)| !predicate_columns.contains(&columns)) - .cloned() - .collect::>() -} - -/// builds a new [LogicalPlan] from `plan` by issuing new [LogicalPlan::Filter] if any of the filters -/// in `state` depend on the columns `used_columns`. -fn issue_filters( - mut state: State, - used_columns: HashSet, - plan: &LogicalPlan, -) -> Result { - let (predicates, predicate_columns) = get_predicates(&state, &used_columns); - - if predicates.is_empty() { - // all filters can be pushed down => optimize inputs and return new plan - return push_down(&state, plan); - } - - let plan = utils::add_filter(plan.clone(), &predicates)?; - - state.filters = remove_filters(&state.filters, &predicate_columns); - - // continue optimization over all input nodes by cloning the current state (i.e. each node is independent) - push_down(&state, &plan) -} - -// For a given JOIN logical plan, determine whether each side of the join is preserved. -// We say a join side is preserved if the join returns all or a subset of the rows from -// the relevant side, such that each row of the output table directly maps to a row of -// the preserved input table. If a table is not preserved, it can provide extra null rows. -// That is, there may be rows in the output table that don't directly map to a row in the -// input table. -// -// For example: -// - In an inner join, both sides are preserved, because each row of the output -// maps directly to a row from each side. -// - In a left join, the left side is preserved and the right is not, because -// there may be rows in the output that don't directly map to a row in the -// right input (due to nulls filling where there is no match on the right). -// -// This is important because we can always push down post-join filters to a preserved -// side of the join, assuming the filter only references columns from that side. For the -// non-preserved side it can be more tricky. -// -// Returns a tuple of booleans - (left_preserved, right_preserved). -fn lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> { - match plan { - LogicalPlan::Join(Join { join_type, .. }) => match join_type { - JoinType::Inner => Ok((true, true)), - JoinType::Left => Ok((true, false)), - JoinType::Right => Ok((false, true)), - JoinType::Full => Ok((false, false)), - // No columns from the right side of the join can be referenced in output - // predicates for semi/anti joins, so whether we specify t/f doesn't matter. 
- JoinType::LeftSemi | JoinType::LeftAnti => Ok((true, false)), - _ => todo!(), - }, - LogicalPlan::CrossJoin(_) => Ok((true, true)), - _ => Err(DataFusionError::Internal( - "lr_is_preserved only valid for JOIN nodes".to_string(), - )), - } -} - -// For a given JOIN logical plan, determine whether each side of the join is preserved -// in terms on join filtering. -// Predicates from join filter can only be pushed to preserved join side. -fn on_lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> { - match plan { - LogicalPlan::Join(Join { join_type, .. }) => match join_type { - JoinType::Inner => Ok((true, true)), - JoinType::Left => Ok((false, true)), - JoinType::Right => Ok((true, false)), - JoinType::Full => Ok((false, false)), - JoinType::LeftSemi | JoinType::LeftAnti => { - // filter_push_down does not yet support SEMI/ANTI joins with join conditions - Ok((false, false)) - } - _ => todo!(), - }, - LogicalPlan::CrossJoin(_) => Err(DataFusionError::Internal( - "on_lr_is_preserved cannot be applied to CROSSJOIN nodes".to_string(), - )), - _ => Err(DataFusionError::Internal( - "on_lr_is_preserved only valid for JOIN nodes".to_string(), - )), - } -} - -// Determine which predicates in state can be pushed down to a given side of a join. -// To determine this, we need to know the schema of the relevant join side and whether -// or not the side's rows are preserved when joining. If the side is not preserved, we -// do not push down anything. Otherwise we can push down predicates where all of the -// relevant columns are contained on the relevant join side's schema. -fn get_pushable_join_predicates<'a>( - filters: &'a [Predicate], - schema: &DFSchema, - preserved: bool, -) -> Predicates<'a> { - if !preserved { - return (vec![], vec![]); - } - - let schema_columns = schema - .fields() - .iter() - .flat_map(|f| { - [ - f.qualified_column(), - // we need to push down filter using unqualified column as well - f.unqualified_column(), - ] - }) - .collect::>(); - - filters - .iter() - .filter(|(_, columns)| { - let all_columns_in_schema = schema_columns - .intersection(columns) - .collect::>() - .len() - == columns.len(); - all_columns_in_schema - }) - .map(|(a, b)| (a, b)) - .unzip() -} - -fn optimize_join( - mut state: State, - plan: &LogicalPlan, - left: &LogicalPlan, - right: &LogicalPlan, - on_filter: Vec, -) -> Result { - // Get pushable predicates from current optimizer state - let (left_preserved, right_preserved) = lr_is_preserved(plan)?; - let to_left = get_pushable_join_predicates(&state.filters, left.schema(), left_preserved); - let to_right = get_pushable_join_predicates(&state.filters, right.schema(), right_preserved); - let to_keep: Predicates = state - .filters - .iter() - .filter(|(e, _)| !to_left.0.contains(&e) && !to_right.0.contains(&e)) - .map(|(a, b)| (a, b)) - .unzip(); - - // Get pushable predicates from join filter - let (on_to_left, on_to_right, on_to_keep) = if on_filter.is_empty() { - ((vec![], vec![]), (vec![], vec![]), vec![]) - } else { - let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(plan)?; - let on_to_left = get_pushable_join_predicates(&on_filter, left.schema(), on_left_preserved); - let on_to_right = - get_pushable_join_predicates(&on_filter, right.schema(), on_right_preserved); - let on_to_keep = on_filter - .iter() - .filter(|(e, _)| !on_to_left.0.contains(&e) && !on_to_right.0.contains(&e)) - .map(|(a, _)| a.clone()) - .collect::>(); - - (on_to_left, on_to_right, on_to_keep) - }; - - // Build new filter states using pushable predicates - 
// from current optimizer states and from ON clause. - // Then recursively call optimization for both join inputs - let mut left_state = State { filters: vec![] }; - left_state.append_predicates(to_left); - left_state.append_predicates(on_to_left); - let left = optimize(left, left_state)?; - - let mut right_state = State { filters: vec![] }; - right_state.append_predicates(to_right); - right_state.append_predicates(on_to_right); - let right = optimize(right, right_state)?; - - // Create a new Join with the new `left` and `right` - // - // expressions() output for Join is a vector consisting of - // 1. join keys - columns mentioned in ON clause - // 2. optional predicate - in case join filter is not empty, - // it always will be the last element, otherwise result - // vector will contain only join keys (without additional - // element representing filter). - let expr = plan.expressions(); - let expr = if !on_filter.is_empty() && on_to_keep.is_empty() { - // New filter expression is None - should remove last element - expr[..expr.len() - 1].to_vec() - } else if !on_to_keep.is_empty() { - // Replace last element with new filter expression - expr[..expr.len() - 1] - .iter() - .cloned() - .chain(once(on_to_keep.into_iter().reduce(Expr::and).unwrap())) - .collect() - } else { - plan.expressions() - }; - let plan = from_plan(plan, &expr, &[left, right])?; - - if to_keep.0.is_empty() { - Ok(plan) - } else { - // wrap the join on the filter whose predicates must be kept - let plan = utils::add_filter(plan, &to_keep.0); - state.filters = remove_filters(&state.filters, &to_keep.1); - plan - } -} - -fn optimize(plan: &LogicalPlan, mut state: State) -> Result { - match plan { - LogicalPlan::Explain { .. } => { - // push the optimization to the plan of this explain - push_down(&state, plan) - } - LogicalPlan::Analyze { .. } => push_down(&state, plan), - LogicalPlan::Filter(filter) => { - let predicates = utils::split_conjunction(filter.predicate()); - - predicates - .into_iter() - .try_for_each::<_, Result<()>>(|predicate| { - let mut columns: HashSet = HashSet::new(); - expr_to_columns(predicate, &mut columns)?; - state.filters.push((predicate.clone(), columns)); - Ok(()) - })?; - - optimize(filter.input(), state) - } - LogicalPlan::Projection(Projection { - input, - expr, - schema, - alias: _, - }) => { - // A projection is filter-commutable, but re-writes all predicate expressions - // collect projection. - let projection = schema - .fields() - .iter() - .enumerate() - .flat_map(|(i, field)| { - // strip alias, as they should not be part of filters - let expr = match &expr[i] { - Expr::Alias(expr, _) => expr.as_ref().clone(), - expr => expr.clone(), - }; - - // Convert both qualified and unqualified fields - [ - (field.name().clone(), expr.clone()), - (field.qualified_name(), expr), - ] - }) - .collect::>(); - - // re-write all filters based on this projection - // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" - for (predicate, columns) in state.filters.iter_mut() { - *predicate = replace_cols_by_name(predicate.clone(), &projection)?; - - columns.clear(); - expr_to_columns(predicate, columns)?; - } - - // optimize inner - let new_input = optimize(input, state)?; - Ok(from_plan(plan, expr, &[new_input])?) - } - LogicalPlan::Aggregate(Aggregate { aggr_expr, .. 
}) => { - // An aggregate's aggreagate columns are _not_ filter-commutable => collect these: - // * columns whose aggregation expression depends on - // * the aggregation columns themselves - - // construct set of columns that `aggr_expr` depends on - let mut used_columns = HashSet::new(); - exprlist_to_columns(aggr_expr, &mut used_columns)?; - - let agg_columns = aggr_expr - .iter() - .map(|x| Ok(Column::from_name(x.display_name()?))) - .collect::>>()?; - used_columns.extend(agg_columns); - - issue_filters(state, used_columns, plan) - } - LogicalPlan::Sort { .. } => { - // sort is filter-commutable - push_down(&state, plan) - } - LogicalPlan::Union(Union { - inputs: _, - schema, - alias: _, - }) => { - // union changing all qualifiers while building logical plan so we need - // to rewrite filters to push unqualified columns to inputs - let projection = schema - .fields() - .iter() - .map(|field| (field.qualified_name(), col(field.name()))) - .collect::>(); - - // rewriting predicate expressions using unqualified names as replacements - if !projection.is_empty() { - for (predicate, columns) in state.filters.iter_mut() { - *predicate = replace_cols_by_name(predicate.clone(), &projection)?; - - columns.clear(); - expr_to_columns(predicate, columns)?; - } - } - - push_down(&state, plan) - } - LogicalPlan::Limit(Limit { input, .. }) => { - // limit is _not_ filter-commutable => collect all columns from its input - let used_columns = input - .schema() - .fields() - .iter() - .map(|f| f.qualified_column()) - .collect::>(); - issue_filters(state, used_columns, plan) - } - LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { - optimize_join(state, plan, left, right, vec![]) - } - LogicalPlan::Join(Join { - left, - right, - on, - filter, - join_type, - .. - }) => { - // Convert JOIN ON predicate to Predicates - let on_filters = filter - .as_ref() - .map(|e| { - let predicates = utils::split_conjunction(e); - - predicates - .into_iter() - .map(|e| { - let mut accum = HashSet::new(); - expr_to_columns(e, &mut accum)?; - Ok((e.clone(), accum)) - }) - .collect::>>() - }) - .unwrap_or_else(|| Ok(vec![]))?; - - if *join_type == JoinType::Inner { - // For inner joins, duplicate filters for joined columns so filters can be pushed down - // to both sides. Take the following query as an example: - // - // ```sql - // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1 - // ``` - // - // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while - // `t2.uid > 1` predicate needs to be pushed down to t2 table scan. - // - // Join clauses with `Using` constraints also take advantage of this logic to make sure - // predicates reference the shared join columns are pushed to both sides. 
- // This logic should also been applied to conditions in JOIN ON clause - let join_side_filters = state - .filters - .iter() - .chain(on_filters.iter()) - .filter_map(|(predicate, columns)| { - let mut join_cols_to_replace = HashMap::new(); - for col in columns.iter() { - for (l, r) in on { - if col == l { - join_cols_to_replace.insert(col, r); - break; - } else if col == r { - join_cols_to_replace.insert(col, l); - break; - } - } - } - - if join_cols_to_replace.is_empty() { - return None; - } - - let join_side_predicate = - match replace_col(predicate.clone(), &join_cols_to_replace) { - Ok(p) => p, - Err(e) => { - return Some(Err(e)); - } - }; - - let join_side_columns = columns - .clone() - .into_iter() - // replace keys in join_cols_to_replace with values in resulting column - // set - .filter(|c| !join_cols_to_replace.contains_key(c)) - .chain(join_cols_to_replace.iter().map(|(_, v)| (*v).clone())) - .collect(); - - Some(Ok((join_side_predicate, join_side_columns))) - }) - .collect::>>()?; - state.filters.extend(join_side_filters); - } - - optimize_join(state, plan, left, right, on_filters) - } - LogicalPlan::TableScan(TableScan { - source, - projected_schema, - filters, - projection, - table_name, - fetch, - }) => { - let mut used_columns = HashSet::new(); - let mut new_filters = filters.clone(); - - for (filter_expr, cols) in &state.filters { - let (preserve_filter_node, add_to_provider) = - match source.supports_filter_pushdown(filter_expr)? { - TableProviderFilterPushDown::Unsupported => (true, false), - TableProviderFilterPushDown::Inexact => (true, true), - TableProviderFilterPushDown::Exact => (false, true), - }; - - if preserve_filter_node { - used_columns.extend(cols.clone()); - } - - if add_to_provider { - // Don't add expression again if it's already present in - // pushed down filters. - if new_filters.contains(filter_expr) { - continue; - } - new_filters.push(filter_expr.clone()); - } - } - - issue_filters( - state, - used_columns, - &LogicalPlan::TableScan(TableScan { - source: source.clone(), - projection: projection.clone(), - projected_schema: projected_schema.clone(), - table_name: table_name.clone(), - filters: new_filters, - fetch: *fetch, - }), - ) - } - _ => { - // all other plans are _not_ filter-commutable - let used_columns = plan - .schema() - .fields() - .iter() - .map(|f| f.qualified_column()) - .collect::>(); - issue_filters(state, used_columns, plan) - } - } -} - -impl OptimizerRule for FilterPushDown { - fn name(&self) -> &str { - "filter_push_down" - } - - fn optimize(&self, plan: &LogicalPlan, _: &mut OptimizerConfig) -> Result { - optimize(plan, State::default()) - } -} - -impl FilterPushDown { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -/// replaces columns by its name on the projection. 
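The `TableScan` arm of the rule above encodes the contract that the rest of this series relies on: for each predicate, the table source reports `Unsupported`, `Inexact`, or `Exact` support, and that answer jointly determines whether the `Filter` node is preserved in the plan and whether the predicate is also handed to the scan. A minimal Python sketch of that decision table (illustrative only, not code from this patch):

```python
# Sketch of the pushdown decision table from the TableScan arm above.
# Returns (preserve_filter_node, add_to_provider) for a support level.
def pushdown_decision(support: str) -> tuple:
    table = {
        "Unsupported": (True, False),  # keep the Filter; the scan sees nothing
        "Inexact": (True, True),       # keep the Filter AND push the predicate down
        "Exact": (False, True),        # the scan applies it fully; the Filter can go
    }
    return table[support]


assert pushdown_decision("Inexact") == (True, True)
```

Returning `Inexact` is the safe default: the retained `Filter` keeps results correct even if the scan only partially applies the predicate, which is why the earlier commits in this series start there before moving to `Exact`.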
-fn replace_cols_by_name(e: Expr, replace_map: &HashMap) -> Result { - struct ColumnReplacer<'a> { - replace_map: &'a HashMap, - } - - impl<'a> ExprRewriter for ColumnReplacer<'a> { - fn mutate(&mut self, expr: Expr) -> Result { - if let Expr::Column(c) = &expr { - match self.replace_map.get(&c.flat_name()) { - Some(new_c) => Ok(new_c.clone()), - None => Ok(expr), - } - } else { - Ok(expr) - } - } - } - - e.rewrite(&mut ColumnReplacer { replace_map }) -} diff --git a/dask_sql/physical/rel/logical/table_scan.py b/dask_sql/physical/rel/logical/table_scan.py index 7a3a8c870..204f1588e 100644 --- a/dask_sql/physical/rel/logical/table_scan.py +++ b/dask_sql/physical/rel/logical/table_scan.py @@ -1,4 +1,6 @@ import logging +import operator +from functools import reduce from typing import TYPE_CHECKING from dask_sql.datacontainer import DataContainer @@ -76,14 +78,15 @@ def _apply_filters(self, table_scan, rel, dc, context): cc = dc.column_container filters = table_scan.getFilters() # All partial filters here are applied in conjunction (&) - df_condition = None - for filter in filters: - filter_condition = RexConverter.convert(rel, filter, dc, context=context) - df_condition = ( - filter_condition - if df_condition is None - else (df_condition & filter_condition) + if filters: + df_condition = reduce( + operator.and_, + [ + RexConverter.convert(rel, rex, dc, context=context) + for rex in filters + ], ) + if len(filters) > 0: df = filter_or_scalar(df, df_condition) From 6df4bd09f48cb61684a54f27f405bc8f7d7e9246 Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Mon, 28 Nov 2022 07:01:22 -0800 Subject: [PATCH 07/15] use filter_pushdown_rule from datafusion --- dask_planner/src/sql/optimizer.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/dask_planner/src/sql/optimizer.rs b/dask_planner/src/sql/optimizer.rs index 8067e8a5e..2f2843763 100644 --- a/dask_planner/src/sql/optimizer.rs +++ b/dask_planner/src/sql/optimizer.rs @@ -10,6 +10,7 @@ use datafusion_optimizer::{ // eliminate_filter::EliminateFilter, eliminate_limit::EliminateLimit, filter_null_join_keys::FilterNullJoinKeys, + filter_push_down::FilterPushDown, inline_table_scan::InlineTableScan, limit_push_down::LimitPushDown, optimizer::{Optimizer, OptimizerRule}, @@ -29,9 +30,6 @@ use log::trace; mod eliminate_agg_distinct; use eliminate_agg_distinct::EliminateAggDistinct; -mod filter_push_down; -use filter_push_down::FilterPushDown; - /// Houses the optimization logic for Dask-SQL. 
This optimization controls the optimizations /// and their ordering in regards to their impact on the underlying `LogicalPlan` instance pub struct DaskSqlOptimizer { From 1964368069552eacbce81d0179f6e45e46fa14ca Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Mon, 28 Nov 2022 08:44:48 -0800 Subject: [PATCH 08/15] Update predicate pushdown tests --- tests/integration/test_filter.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/integration/test_filter.py b/tests/integration/test_filter.py index 69b964514..b2e7ceb69 100644 --- a/tests/integration/test_filter.py +++ b/tests/integration/test_filter.py @@ -162,10 +162,10 @@ def test_filter_year(c): pytest.param( "SELECT * FROM parquet_ddf WHERE b IN (1, 6)", lambda x: x[(x["b"] == 1) | (x["b"] == 6)], - [[("b", "<=", 1), ("b", ">=", 1)], [("b", "<=", 6), ("b", ">=", 6)]], - marks=pytest.mark.xfail( - reason="WIP https://github.com/dask-contrib/dask-sql/issues/607" - ), + [[("b", "==", 1)], [("b", "==", 6)]], + # marks=pytest.mark.xfail( + # reason="WIP https://github.com/dask-contrib/dask-sql/issues/607" + # ), ), ( "SELECT a FROM parquet_ddf WHERE (b > 5 AND b < 10) OR a = 1", @@ -206,6 +206,7 @@ def test_predicate_pushdown(c, parquet_ddf, query, df_func, filters): if expect_filters: got_filters = frozenset(frozenset(v) for v in got_filters) expect_filters = frozenset(frozenset(v) for v in filters) + assert got_filters == expect_filters # Check computed result is correct From 49e90d6555af9341c9ca041f07142029626a7e3c Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Mon, 28 Nov 2022 11:31:31 -0800 Subject: [PATCH 09/15] Update predicate pushdown tests --- tests/integration/test_filter.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/tests/integration/test_filter.py b/tests/integration/test_filter.py index b2e7ceb69..db702ad79 100644 --- a/tests/integration/test_filter.py +++ b/tests/integration/test_filter.py @@ -157,20 +157,35 @@ def test_filter_year(c): ( "SELECT * FROM parquet_ddf WHERE (b > 5 AND b < 10) OR a = 1", lambda x: x[((x["b"] > 5) & (x["b"] < 10)) | (x["a"] == 1)], - [[("a", "==", 1)], [("b", "<", 10), ("b", ">", 5)]], + [ + [("a", "==", 1), ("b", "<", 10)], + [("a", "==", 1), ("b", ">", 5)], + [("b", ">", 5), ("b", "<", 10)], + [("a", "==", 1)], + ], ), pytest.param( "SELECT * FROM parquet_ddf WHERE b IN (1, 6)", lambda x: x[(x["b"] == 1) | (x["b"] == 6)], [[("b", "==", 1)], [("b", "==", 6)]], - # marks=pytest.mark.xfail( - # reason="WIP https://github.com/dask-contrib/dask-sql/issues/607" - # ), + ), + pytest.param( + "SELECT * FROM parquet_ddf WHERE b IN (1, 3, 5, 6)", + lambda x: x[(x["b"] == 1) | (x["b"] == 3) | (x["b"] == 5) | (x["b"] == 6)], + [[("b", "==", 1)], [("b", "==", 3)], [("b", "==", 5)], [("b", "==", 6)]], + marks=pytest.mark.xfail( + reason="WIP https://github.com/dask-contrib/dask-sql/issues/607" + ), ), ( "SELECT a FROM parquet_ddf WHERE (b > 5 AND b < 10) OR a = 1", lambda x: x[((x["b"] > 5) & (x["b"] < 10)) | (x["a"] == 1)][["a"]], - [[("a", "==", 1)], [("b", "<", 10), ("b", ">", 5)]], + [ + [("a", "==", 1), ("b", "<", 10)], + [("a", "==", 1), ("b", ">", 5)], + [("b", ">", 5), ("b", "<", 10)], + [("a", "==", 1)], + ], ), ( # Original filters NOT in disjunctive normal form From 294bda1a2ccc530cb1417f87e0d5acc919323671 Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Tue, 29 Nov 2022 04:24:54 -0800 Subject: [PATCH 10/15] unxfail q21 --- tests/unit/test_queries.py | 1 - 1 file changed, 
1 deletion(-) diff --git a/tests/unit/test_queries.py b/tests/unit/test_queries.py index f35bd5750..b23dab811 100644 --- a/tests/unit/test_queries.py +++ b/tests/unit/test_queries.py @@ -11,7 +11,6 @@ 14, 16, 18, - 21, 22, 23, 24, From cfe29057fb3286f72d7e5382c5ac8ec61ef299f6 Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Wed, 30 Nov 2022 06:28:46 -0800 Subject: [PATCH 11/15] Update DaskTableSource filterPushDown comments --- dask_planner/src/sql/table.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/dask_planner/src/sql/table.rs b/dask_planner/src/sql/table.rs index 76c24af57..b4d060b39 100644 --- a/dask_planner/src/sql/table.rs +++ b/dask_planner/src/sql/table.rs @@ -53,13 +53,11 @@ impl TableSource for DaskTableSource { ) -> datafusion_common::Result { let filters = split_conjunction(filter); if filters.iter().all(|f| is_supported_push_down_expr(f)) { - // TODO this should return Exact but we cannot make that change until we - // are actually pushing the TableScan filters down to the reader because - // returning Exact here would remove the Filter from the plan + // Push down filters to the tablescan operation if all are supported Ok(TableProviderFilterPushDown::Exact) } else if filters.iter().any(|f| is_supported_push_down_expr(f)) { - // we can partially apply the filter in the TableScan but we need - // to retain the Filter operator in the plan as well + // Partially apply the filter in the TableScan but retain + // the Filter operator in the plan as well Ok(TableProviderFilterPushDown::Inexact) } else { Ok(TableProviderFilterPushDown::Unsupported) From 3b85ac45d453d1226fd4c617823c3b2a2a5d2a59 Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Wed, 30 Nov 2022 08:21:11 -0800 Subject: [PATCH 12/15] Reenable clippy check for supports_filter_pushdown --- dask_planner/src/sql/table.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/dask_planner/src/sql/table.rs b/dask_planner/src/sql/table.rs index b4d060b39..679559319 100644 --- a/dask_planner/src/sql/table.rs +++ b/dask_planner/src/sql/table.rs @@ -45,8 +45,6 @@ impl TableSource for DaskTableSource { self.schema.clone() } - // temporarily disable clippy until TODO comment below is addressed - #[allow(clippy::if_same_then_else)] fn supports_filter_pushdown( &self, filter: &Expr, From 521129c1db7fc6452f93b066fa9b5d66e1472ce5 Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Wed, 30 Nov 2022 08:48:26 -0800 Subject: [PATCH 13/15] Simplify apply_filter conditional check --- dask_sql/physical/rel/logical/table_scan.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dask_sql/physical/rel/logical/table_scan.py b/dask_sql/physical/rel/logical/table_scan.py index 204f1588e..8bd7874f2 100644 --- a/dask_sql/physical/rel/logical/table_scan.py +++ b/dask_sql/physical/rel/logical/table_scan.py @@ -86,8 +86,6 @@ def _apply_filters(self, table_scan, rel, dc, context): for rex in filters ], ) - - if len(filters) > 0: df = filter_or_scalar(df, df_condition) return DataContainer(df, cc) From 47eee60b80fe32f4724bd11e7bebb62ea4e55a24 Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Thu, 1 Dec 2022 07:09:40 -0800 Subject: [PATCH 14/15] Un-xfail q40 --- tests/unit/test_queries.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/test_queries.py b/tests/unit/test_queries.py index 4f1602a50..becb378a5 100644 --- a/tests/unit/test_queries.py +++ b/tests/unit/test_queries.py @@ -19,7 +19,6 @@ 35, 36, 39, - 40, 41, 44, 45, From 
2090f8609bd330241de0b3b83509ab3bd73b80ed Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Thu, 1 Dec 2022 08:01:03 -0800 Subject: [PATCH 15/15] Rerun tests
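Taken together, the series moves filter evaluation into the `TableScan` itself: the Rust side reports pushdown support and exposes the scan's predicates through `getFilters`, and the Python plugin AND-combines them with `reduce(operator.and_, ...)` before applying the result via `filter_or_scalar`. A minimal end-to-end sketch of that flow — the DataFrame contents and parquet path here are hypothetical, though the DNF filters mirror the ones asserted in `test_predicate_pushdown`:

```python
import operator
from functools import reduce

import pandas as pd

df = pd.DataFrame({"a": [1, 1, 2, 3], "b": [1, 5, 6, 9]})

# _apply_filters-style conjunction: each pushed-down predicate becomes a
# boolean mask, and all masks are AND-ed into a single condition.
predicates = [df["b"] >= 1, df["b"] <= 6]
condition = reduce(operator.and_, predicates)
filtered = df[condition]  # rows where b >= 1 AND b <= 6

# For parquet-backed tables the same predicates can travel further, as
# disjunctive-normal-form filters that dask hands to the reader; e.g.
# "WHERE b IN (1, 6)" arrives as an OR of single-clause AND groups:
# import dask.dataframe as dd
# ddf = dd.read_parquet("path/to/parquet_ddf", filters=[[("b", "==", 1)], [("b", "==", 6)]])
```

Because the plugin applies the full conjunction itself, the table source can report `Exact` support and let the optimizer drop the standalone `Filter` node; the parquet-level filters then act as an additional I/O optimization rather than the sole enforcement of the predicate.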