diff --git a/dask_planner/src/sql/logical/table_scan.rs b/dask_planner/src/sql/logical/table_scan.rs
index db0fbf599..537f011cc 100644
--- a/dask_planner/src/sql/logical/table_scan.rs
+++ b/dask_planner/src/sql/logical/table_scan.rs
@@ -1,12 +1,19 @@
-use datafusion_expr::logical_plan::TableScan;
+use std::sync::Arc;
+
+use datafusion_common::DFSchema;
+use datafusion_expr::{logical_plan::TableScan, LogicalPlan};
 use pyo3::prelude::*;
 
-use crate::sql::{exceptions::py_type_err, logical};
+use crate::{
+    expression::{py_expr_list, PyExpr},
+    sql::exceptions::py_type_err,
+};
 
 #[pyclass(name = "TableScan", module = "dask_planner", subclass)]
 #[derive(Clone)]
 pub struct PyTableScan {
     pub(crate) table_scan: TableScan,
+    input: Arc<LogicalPlan>,
 }
 
 #[pymethods]
@@ -31,14 +38,32 @@ impl PyTableScan {
     fn contains_projections(&self) -> bool {
         self.table_scan.projection.is_some()
     }
+
+    #[pyo3(name = "getFilters")]
+    fn scan_filters(&self) -> PyResult<Vec<PyExpr>> {
+        py_expr_list(&self.input, &self.table_scan.filters)
+    }
 }
 
-impl TryFrom<logical::LogicalPlan> for PyTableScan {
+impl TryFrom<LogicalPlan> for PyTableScan {
     type Error = PyErr;
 
-    fn try_from(logical_plan: logical::LogicalPlan) -> Result<Self, Self::Error> {
+    fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> {
         match logical_plan {
-            logical::LogicalPlan::TableScan(table_scan) => Ok(PyTableScan { table_scan }),
+            LogicalPlan::TableScan(table_scan) => {
+                // Create an input logical plan that's identical to the table scan with schema from the table source
+                let mut input = table_scan.clone();
+                input.projected_schema = DFSchema::try_from_qualified_schema(
+                    &table_scan.table_name,
+                    &table_scan.source.schema(),
+                )
+                .map_or(input.projected_schema, Arc::new);
+
+                Ok(PyTableScan {
+                    table_scan,
+                    input: Arc::new(LogicalPlan::TableScan(input)),
+                })
+            }
             _ => Err(py_type_err("unexpected plan")),
         }
     }
diff --git a/dask_planner/src/sql/optimizer.rs b/dask_planner/src/sql/optimizer.rs
index 8067e8a5e..2f2843763 100644
--- a/dask_planner/src/sql/optimizer.rs
+++ b/dask_planner/src/sql/optimizer.rs
@@ -10,6 +10,7 @@ use datafusion_optimizer::{
     // eliminate_filter::EliminateFilter,
     eliminate_limit::EliminateLimit,
     filter_null_join_keys::FilterNullJoinKeys,
+    filter_push_down::FilterPushDown,
     inline_table_scan::InlineTableScan,
     limit_push_down::LimitPushDown,
     optimizer::{Optimizer, OptimizerRule},
@@ -29,9 +30,6 @@ use log::trace;
 mod eliminate_agg_distinct;
 use eliminate_agg_distinct::EliminateAggDistinct;
 
-mod filter_push_down;
-use filter_push_down::FilterPushDown;
-
 /// Houses the optimization logic for Dask-SQL. This optimization controls the optimizations
 /// and their ordering in regards to their impact on the underlying `LogicalPlan` instance
 pub struct DaskSqlOptimizer {
diff --git a/dask_planner/src/sql/optimizer/filter_push_down.rs b/dask_planner/src/sql/optimizer/filter_push_down.rs
deleted file mode 100644
index ac0429774..000000000
--- a/dask_planner/src/sql/optimizer/filter_push_down.rs
+++ /dev/null
@@ -1,641 +0,0 @@
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.
See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan - -use std::{ - collections::{HashMap, HashSet}, - iter::once, -}; - -use datafusion_common::{Column, DFSchema, DataFusionError, Result}; -use datafusion_expr::{ - col, - expr_rewriter::{replace_col, ExprRewritable, ExprRewriter}, - logical_plan::{ - Aggregate, - CrossJoin, - Join, - JoinType, - Limit, - LogicalPlan, - Projection, - TableScan, - Union, - }, - utils::{expr_to_columns, exprlist_to_columns, from_plan}, - Expr, - TableProviderFilterPushDown, -}; -use datafusion_optimizer::{utils, OptimizerConfig, OptimizerRule}; - -/// Filter Push Down optimizer rule pushes filter clauses down the plan -/// # Introduction -/// A filter-commutative operation is an operation whose result of filter(op(data)) = op(filter(data)). -/// An example of a filter-commutative operation is a projection; a counter-example is `limit`. -/// -/// The filter-commutative property is column-specific. An aggregate grouped by A on SUM(B) -/// can commute with a filter that depends on A only, but does not commute with a filter that depends -/// on SUM(B). -/// -/// This optimizer commutes filters with filter-commutative operations to push the filters -/// the closest possible to the scans, re-writing the filter expressions by every -/// projection that changes the filter's expression. -/// -/// Filter: b Gt Int64(10) -/// Projection: a AS b -/// -/// is optimized to -/// -/// Projection: a AS b -/// Filter: a Gt Int64(10) <--- changed from b to a -/// -/// This performs a single pass through the plan. When it passes through a filter, it stores that filter, -/// and when it reaches a node that does not commute with it, it adds the filter to that place. -/// When it passes through a projection, it re-writes the filter's expression taking into account that projection. -/// When multiple filters would have been written, it `AND` their expressions into a single expression. -#[derive(Default)] -pub struct FilterPushDown {} - -/// Filter predicate represented by tuple of expression and its columns -type Predicate = (Expr, HashSet); - -/// Multiple filter predicates represented by tuple of expressions vector -/// and corresponding expression columns vector -type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet>); - -#[derive(Debug, Clone, Default)] -struct State { - // (predicate, columns on the predicate) - filters: Vec, -} - -impl State { - fn append_predicates(&mut self, predicates: Predicates) { - predicates - .0 - .into_iter() - .zip(predicates.1) - .for_each(|(expr, cols)| self.filters.push((expr.clone(), cols.clone()))) - } -} - -/// returns all predicates in `state` that depend on any of `used_columns` -/// or the ones that does not reference any columns (e.g. 
WHERE 1=1) -fn get_predicates<'a>(state: &'a State, used_columns: &HashSet) -> Predicates<'a> { - state - .filters - .iter() - .filter(|(_, columns)| { - columns.is_empty() - || !columns - .intersection(used_columns) - .collect::>() - .is_empty() - }) - .map(|&(ref a, ref b)| (a, b)) - .unzip() -} - -/// Optimizes the plan -fn push_down(state: &State, plan: &LogicalPlan) -> Result { - let new_inputs = plan - .inputs() - .iter() - .map(|input| optimize(input, state.clone())) - .collect::>>()?; - - let expr = plan.expressions(); - from_plan(plan, &expr, &new_inputs) -} - -// remove all filters from `filters` that are in `predicate_columns` -fn remove_filters(filters: &[Predicate], predicate_columns: &[&HashSet]) -> Vec { - filters - .iter() - .filter(|(_, columns)| !predicate_columns.contains(&columns)) - .cloned() - .collect::>() -} - -/// builds a new [LogicalPlan] from `plan` by issuing new [LogicalPlan::Filter] if any of the filters -/// in `state` depend on the columns `used_columns`. -fn issue_filters( - mut state: State, - used_columns: HashSet, - plan: &LogicalPlan, -) -> Result { - let (predicates, predicate_columns) = get_predicates(&state, &used_columns); - - if predicates.is_empty() { - // all filters can be pushed down => optimize inputs and return new plan - return push_down(&state, plan); - } - - let plan = utils::add_filter(plan.clone(), &predicates)?; - - state.filters = remove_filters(&state.filters, &predicate_columns); - - // continue optimization over all input nodes by cloning the current state (i.e. each node is independent) - push_down(&state, &plan) -} - -// For a given JOIN logical plan, determine whether each side of the join is preserved. -// We say a join side is preserved if the join returns all or a subset of the rows from -// the relevant side, such that each row of the output table directly maps to a row of -// the preserved input table. If a table is not preserved, it can provide extra null rows. -// That is, there may be rows in the output table that don't directly map to a row in the -// input table. -// -// For example: -// - In an inner join, both sides are preserved, because each row of the output -// maps directly to a row from each side. -// - In a left join, the left side is preserved and the right is not, because -// there may be rows in the output that don't directly map to a row in the -// right input (due to nulls filling where there is no match on the right). -// -// This is important because we can always push down post-join filters to a preserved -// side of the join, assuming the filter only references columns from that side. For the -// non-preserved side it can be more tricky. -// -// Returns a tuple of booleans - (left_preserved, right_preserved). -fn lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> { - match plan { - LogicalPlan::Join(Join { join_type, .. }) => match join_type { - JoinType::Inner => Ok((true, true)), - JoinType::Left => Ok((true, false)), - JoinType::Right => Ok((false, true)), - JoinType::Full => Ok((false, false)), - // No columns from the right side of the join can be referenced in output - // predicates for semi/anti joins, so whether we specify t/f doesn't matter. 
- JoinType::LeftSemi | JoinType::LeftAnti => Ok((true, false)), - _ => todo!(), - }, - LogicalPlan::CrossJoin(_) => Ok((true, true)), - _ => Err(DataFusionError::Internal( - "lr_is_preserved only valid for JOIN nodes".to_string(), - )), - } -} - -// For a given JOIN logical plan, determine whether each side of the join is preserved -// in terms on join filtering. -// Predicates from join filter can only be pushed to preserved join side. -fn on_lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> { - match plan { - LogicalPlan::Join(Join { join_type, .. }) => match join_type { - JoinType::Inner => Ok((true, true)), - JoinType::Left => Ok((false, true)), - JoinType::Right => Ok((true, false)), - JoinType::Full => Ok((false, false)), - JoinType::LeftSemi | JoinType::LeftAnti => { - // filter_push_down does not yet support SEMI/ANTI joins with join conditions - Ok((false, false)) - } - _ => todo!(), - }, - LogicalPlan::CrossJoin(_) => Err(DataFusionError::Internal( - "on_lr_is_preserved cannot be applied to CROSSJOIN nodes".to_string(), - )), - _ => Err(DataFusionError::Internal( - "on_lr_is_preserved only valid for JOIN nodes".to_string(), - )), - } -} - -// Determine which predicates in state can be pushed down to a given side of a join. -// To determine this, we need to know the schema of the relevant join side and whether -// or not the side's rows are preserved when joining. If the side is not preserved, we -// do not push down anything. Otherwise we can push down predicates where all of the -// relevant columns are contained on the relevant join side's schema. -fn get_pushable_join_predicates<'a>( - filters: &'a [Predicate], - schema: &DFSchema, - preserved: bool, -) -> Predicates<'a> { - if !preserved { - return (vec![], vec![]); - } - - let schema_columns = schema - .fields() - .iter() - .flat_map(|f| { - [ - f.qualified_column(), - // we need to push down filter using unqualified column as well - f.unqualified_column(), - ] - }) - .collect::>(); - - filters - .iter() - .filter(|(_, columns)| { - let all_columns_in_schema = schema_columns - .intersection(columns) - .collect::>() - .len() - == columns.len(); - all_columns_in_schema - }) - .map(|(a, b)| (a, b)) - .unzip() -} - -fn optimize_join( - mut state: State, - plan: &LogicalPlan, - left: &LogicalPlan, - right: &LogicalPlan, - on_filter: Vec, -) -> Result { - // Get pushable predicates from current optimizer state - let (left_preserved, right_preserved) = lr_is_preserved(plan)?; - let to_left = get_pushable_join_predicates(&state.filters, left.schema(), left_preserved); - let to_right = get_pushable_join_predicates(&state.filters, right.schema(), right_preserved); - let to_keep: Predicates = state - .filters - .iter() - .filter(|(e, _)| !to_left.0.contains(&e) && !to_right.0.contains(&e)) - .map(|(a, b)| (a, b)) - .unzip(); - - // Get pushable predicates from join filter - let (on_to_left, on_to_right, on_to_keep) = if on_filter.is_empty() { - ((vec![], vec![]), (vec![], vec![]), vec![]) - } else { - let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(plan)?; - let on_to_left = get_pushable_join_predicates(&on_filter, left.schema(), on_left_preserved); - let on_to_right = - get_pushable_join_predicates(&on_filter, right.schema(), on_right_preserved); - let on_to_keep = on_filter - .iter() - .filter(|(e, _)| !on_to_left.0.contains(&e) && !on_to_right.0.contains(&e)) - .map(|(a, _)| a.clone()) - .collect::>(); - - (on_to_left, on_to_right, on_to_keep) - }; - - // Build new filter states using pushable predicates - 
// from current optimizer states and from ON clause. - // Then recursively call optimization for both join inputs - let mut left_state = State { filters: vec![] }; - left_state.append_predicates(to_left); - left_state.append_predicates(on_to_left); - let left = optimize(left, left_state)?; - - let mut right_state = State { filters: vec![] }; - right_state.append_predicates(to_right); - right_state.append_predicates(on_to_right); - let right = optimize(right, right_state)?; - - // Create a new Join with the new `left` and `right` - // - // expressions() output for Join is a vector consisting of - // 1. join keys - columns mentioned in ON clause - // 2. optional predicate - in case join filter is not empty, - // it always will be the last element, otherwise result - // vector will contain only join keys (without additional - // element representing filter). - let expr = plan.expressions(); - let expr = if !on_filter.is_empty() && on_to_keep.is_empty() { - // New filter expression is None - should remove last element - expr[..expr.len() - 1].to_vec() - } else if !on_to_keep.is_empty() { - // Replace last element with new filter expression - expr[..expr.len() - 1] - .iter() - .cloned() - .chain(once(on_to_keep.into_iter().reduce(Expr::and).unwrap())) - .collect() - } else { - plan.expressions() - }; - let plan = from_plan(plan, &expr, &[left, right])?; - - if to_keep.0.is_empty() { - Ok(plan) - } else { - // wrap the join on the filter whose predicates must be kept - let plan = utils::add_filter(plan, &to_keep.0); - state.filters = remove_filters(&state.filters, &to_keep.1); - plan - } -} - -fn optimize(plan: &LogicalPlan, mut state: State) -> Result { - match plan { - LogicalPlan::Explain { .. } => { - // push the optimization to the plan of this explain - push_down(&state, plan) - } - LogicalPlan::Analyze { .. } => push_down(&state, plan), - LogicalPlan::Filter(filter) => { - let predicates = utils::split_conjunction(filter.predicate()); - - predicates - .into_iter() - .try_for_each::<_, Result<()>>(|predicate| { - let mut columns: HashSet = HashSet::new(); - expr_to_columns(predicate, &mut columns)?; - state.filters.push((predicate.clone(), columns)); - Ok(()) - })?; - - optimize(filter.input(), state) - } - LogicalPlan::Projection(Projection { - input, - expr, - schema, - alias: _, - }) => { - // A projection is filter-commutable, but re-writes all predicate expressions - // collect projection. - let projection = schema - .fields() - .iter() - .enumerate() - .flat_map(|(i, field)| { - // strip alias, as they should not be part of filters - let expr = match &expr[i] { - Expr::Alias(expr, _) => expr.as_ref().clone(), - expr => expr.clone(), - }; - - // Convert both qualified and unqualified fields - [ - (field.name().clone(), expr.clone()), - (field.qualified_name(), expr), - ] - }) - .collect::>(); - - // re-write all filters based on this projection - // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" - for (predicate, columns) in state.filters.iter_mut() { - *predicate = replace_cols_by_name(predicate.clone(), &projection)?; - - columns.clear(); - expr_to_columns(predicate, columns)?; - } - - // optimize inner - let new_input = optimize(input, state)?; - Ok(from_plan(plan, expr, &[new_input])?) - } - LogicalPlan::Aggregate(Aggregate { aggr_expr, .. 
}) => { - // An aggregate's aggreagate columns are _not_ filter-commutable => collect these: - // * columns whose aggregation expression depends on - // * the aggregation columns themselves - - // construct set of columns that `aggr_expr` depends on - let mut used_columns = HashSet::new(); - exprlist_to_columns(aggr_expr, &mut used_columns)?; - - let agg_columns = aggr_expr - .iter() - .map(|x| Ok(Column::from_name(x.display_name()?))) - .collect::>>()?; - used_columns.extend(agg_columns); - - issue_filters(state, used_columns, plan) - } - LogicalPlan::Sort { .. } => { - // sort is filter-commutable - push_down(&state, plan) - } - LogicalPlan::Union(Union { - inputs: _, - schema, - alias: _, - }) => { - // union changing all qualifiers while building logical plan so we need - // to rewrite filters to push unqualified columns to inputs - let projection = schema - .fields() - .iter() - .map(|field| (field.qualified_name(), col(field.name()))) - .collect::>(); - - // rewriting predicate expressions using unqualified names as replacements - if !projection.is_empty() { - for (predicate, columns) in state.filters.iter_mut() { - *predicate = replace_cols_by_name(predicate.clone(), &projection)?; - - columns.clear(); - expr_to_columns(predicate, columns)?; - } - } - - push_down(&state, plan) - } - LogicalPlan::Limit(Limit { input, .. }) => { - // limit is _not_ filter-commutable => collect all columns from its input - let used_columns = input - .schema() - .fields() - .iter() - .map(|f| f.qualified_column()) - .collect::>(); - issue_filters(state, used_columns, plan) - } - LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { - optimize_join(state, plan, left, right, vec![]) - } - LogicalPlan::Join(Join { - left, - right, - on, - filter, - join_type, - .. - }) => { - // Convert JOIN ON predicate to Predicates - let on_filters = filter - .as_ref() - .map(|e| { - let predicates = utils::split_conjunction(e); - - predicates - .into_iter() - .map(|e| { - let mut accum = HashSet::new(); - expr_to_columns(e, &mut accum)?; - Ok((e.clone(), accum)) - }) - .collect::>>() - }) - .unwrap_or_else(|| Ok(vec![]))?; - - if *join_type == JoinType::Inner { - // For inner joins, duplicate filters for joined columns so filters can be pushed down - // to both sides. Take the following query as an example: - // - // ```sql - // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1 - // ``` - // - // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while - // `t2.uid > 1` predicate needs to be pushed down to t2 table scan. - // - // Join clauses with `Using` constraints also take advantage of this logic to make sure - // predicates reference the shared join columns are pushed to both sides. 
- // This logic should also been applied to conditions in JOIN ON clause - let join_side_filters = state - .filters - .iter() - .chain(on_filters.iter()) - .filter_map(|(predicate, columns)| { - let mut join_cols_to_replace = HashMap::new(); - for col in columns.iter() { - for (l, r) in on { - if col == l { - join_cols_to_replace.insert(col, r); - break; - } else if col == r { - join_cols_to_replace.insert(col, l); - break; - } - } - } - - if join_cols_to_replace.is_empty() { - return None; - } - - let join_side_predicate = - match replace_col(predicate.clone(), &join_cols_to_replace) { - Ok(p) => p, - Err(e) => { - return Some(Err(e)); - } - }; - - let join_side_columns = columns - .clone() - .into_iter() - // replace keys in join_cols_to_replace with values in resulting column - // set - .filter(|c| !join_cols_to_replace.contains_key(c)) - .chain(join_cols_to_replace.iter().map(|(_, v)| (*v).clone())) - .collect(); - - Some(Ok((join_side_predicate, join_side_columns))) - }) - .collect::>>()?; - state.filters.extend(join_side_filters); - } - - optimize_join(state, plan, left, right, on_filters) - } - LogicalPlan::TableScan(TableScan { - source, - projected_schema, - filters, - projection, - table_name, - fetch, - }) => { - let mut used_columns = HashSet::new(); - let mut new_filters = filters.clone(); - - for (filter_expr, cols) in &state.filters { - let (preserve_filter_node, add_to_provider) = - match source.supports_filter_pushdown(filter_expr)? { - TableProviderFilterPushDown::Unsupported => (true, false), - TableProviderFilterPushDown::Inexact => (true, true), - TableProviderFilterPushDown::Exact => (false, true), - }; - - if preserve_filter_node { - used_columns.extend(cols.clone()); - } - - if add_to_provider { - // Don't add expression again if it's already present in - // pushed down filters. - if new_filters.contains(filter_expr) { - continue; - } - new_filters.push(filter_expr.clone()); - } - } - - issue_filters( - state, - used_columns, - &LogicalPlan::TableScan(TableScan { - source: source.clone(), - projection: projection.clone(), - projected_schema: projected_schema.clone(), - table_name: table_name.clone(), - filters: new_filters, - fetch: *fetch, - }), - ) - } - _ => { - // all other plans are _not_ filter-commutable - let used_columns = plan - .schema() - .fields() - .iter() - .map(|f| f.qualified_column()) - .collect::>(); - issue_filters(state, used_columns, plan) - } - } -} - -impl OptimizerRule for FilterPushDown { - fn name(&self) -> &str { - "filter_push_down" - } - - fn optimize(&self, plan: &LogicalPlan, _: &mut OptimizerConfig) -> Result { - optimize(plan, State::default()) - } -} - -impl FilterPushDown { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -/// replaces columns by its name on the projection. 
-fn replace_cols_by_name(e: Expr, replace_map: &HashMap<String, Expr>) -> Result<Expr> {
-    struct ColumnReplacer<'a> {
-        replace_map: &'a HashMap<String, Expr>,
-    }
-
-    impl<'a> ExprRewriter for ColumnReplacer<'a> {
-        fn mutate(&mut self, expr: Expr) -> Result<Expr> {
-            if let Expr::Column(c) = &expr {
-                match self.replace_map.get(&c.flat_name()) {
-                    Some(new_c) => Ok(new_c.clone()),
-                    None => Ok(expr),
-                }
-            } else {
-                Ok(expr)
-            }
-        }
-    }
-
-    e.rewrite(&mut ColumnReplacer { replace_map })
-}
diff --git a/dask_planner/src/sql/table.rs b/dask_planner/src/sql/table.rs
index c44eec38c..679559319 100644
--- a/dask_planner/src/sql/table.rs
+++ b/dask_planner/src/sql/table.rs
@@ -45,21 +45,17 @@ impl TableSource for DaskTableSource {
         self.schema.clone()
     }
 
-    // temporarily disable clippy until TODO comment below is addressed
-    #[allow(clippy::if_same_then_else)]
     fn supports_filter_pushdown(
         &self,
         filter: &Expr,
     ) -> datafusion_common::Result<TableProviderFilterPushDown> {
         let filters = split_conjunction(filter);
         if filters.iter().all(|f| is_supported_push_down_expr(f)) {
-            // TODO this should return Exact but we cannot make that change until we
-            // are actually pushing the TableScan filters down to the reader because
-            // returning Exact here would remove the Filter from the plan
-            Ok(TableProviderFilterPushDown::Inexact)
+            // Push down filters to the tablescan operation if all are supported
+            Ok(TableProviderFilterPushDown::Exact)
         } else if filters.iter().any(|f| is_supported_push_down_expr(f)) {
-            // we can partially apply the filter in the TableScan but we need
-            // to retain the Filter operator in the plan as well
+            // Partially apply the filter in the TableScan but retain
+            // the Filter operator in the plan as well
             Ok(TableProviderFilterPushDown::Inexact)
         } else {
             Ok(TableProviderFilterPushDown::Unsupported)
@@ -67,12 +63,9 @@ impl TableSource for DaskTableSource {
     }
 }
 
-fn is_supported_push_down_expr(expr: &Expr) -> bool {
-    match expr {
-        // for now, we just attempt to push down simple IS NOT NULL filters on columns
-        Expr::IsNotNull(ref a) => matches!(a.as_ref(), Expr::Column(_)),
-        _ => false,
-    }
+fn is_supported_push_down_expr(_expr: &Expr) -> bool {
+    // For now we support all kinds of expr's at this level
+    true
 }
 
 #[pyclass(name = "DaskStatistics", module = "dask_planner", subclass)]
diff --git a/dask_sql/physical/rel/logical/table_scan.py b/dask_sql/physical/rel/logical/table_scan.py
index 716e51dcd..8bd7874f2 100644
--- a/dask_sql/physical/rel/logical/table_scan.py
+++ b/dask_sql/physical/rel/logical/table_scan.py
@@ -1,8 +1,12 @@
 import logging
+import operator
+from functools import reduce
 from typing import TYPE_CHECKING
 
 from dask_sql.datacontainer import DataContainer
 from dask_sql.physical.rel.base import BaseRelPlugin
+from dask_sql.physical.rel.logical.filter import filter_or_scalar
+from dask_sql.physical.rex import RexConverter
 
 if TYPE_CHECKING:
     import dask_sql
@@ -13,14 +17,13 @@ class DaskTableScanPlugin(BaseRelPlugin):
     """
-    A DaskTableScal is the main ingredient: it will get the data
+    A DaskTableScan is the main ingredient: it will get the data
     from the database. It is always used, when the SQL looks like

     SELECT .... FROM table ....

     We need to get the dask dataframe from the registered
     tables and return the requested columns from it.
-    Calcite will always refer to columns via index.
""" class_name = "TableScan" @@ -41,12 +44,23 @@ def convert( schema_name, table_name = [n.lower() for n in context.fqn(dask_table)] dc = context.schema[schema_name].tables[table_name] - df = dc.df + + # Apply filter before projections since filter columns may not be in projections + dc = self._apply_filters(table_scan, rel, dc, context) + dc = self._apply_projections(table_scan, dask_table, dc) + cc = dc.column_container + cc = self.fix_column_to_row_type(cc, rel.getRowType()) + dc = DataContainer(dc.df, cc) + dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) + return dc + def _apply_projections(self, table_scan, dask_table, dc): # If the 'TableScan' instance contains projected columns only retrieve those columns # otherwise get all projected columns from the 'Projection' instance, which is contained # in the 'RelDataType' instance, aka 'row_type' + df = dc.df + cc = dc.column_container if table_scan.containsProjections(): field_specifications = ( table_scan.getTableScanProjects() @@ -56,9 +70,22 @@ def convert( field_specifications = [ str(f) for f in dask_table.getRowType().getFieldNames() ] - cc = cc.limit_to(field_specifications) - cc = self.fix_column_to_row_type(cc, rel.getRowType()) - dc = DataContainer(df, cc) - dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) - return dc + return DataContainer(df, cc) + + def _apply_filters(self, table_scan, rel, dc, context): + df = dc.df + cc = dc.column_container + filters = table_scan.getFilters() + # All partial filters here are applied in conjunction (&) + if filters: + df_condition = reduce( + operator.and_, + [ + RexConverter.convert(rel, rex, dc, context=context) + for rex in filters + ], + ) + df = filter_or_scalar(df, df_condition) + + return DataContainer(df, cc) diff --git a/tests/integration/test_filter.py b/tests/integration/test_filter.py index 69b964514..db702ad79 100644 --- a/tests/integration/test_filter.py +++ b/tests/integration/test_filter.py @@ -157,12 +157,22 @@ def test_filter_year(c): ( "SELECT * FROM parquet_ddf WHERE (b > 5 AND b < 10) OR a = 1", lambda x: x[((x["b"] > 5) & (x["b"] < 10)) | (x["a"] == 1)], - [[("a", "==", 1)], [("b", "<", 10), ("b", ">", 5)]], + [ + [("a", "==", 1), ("b", "<", 10)], + [("a", "==", 1), ("b", ">", 5)], + [("b", ">", 5), ("b", "<", 10)], + [("a", "==", 1)], + ], ), pytest.param( "SELECT * FROM parquet_ddf WHERE b IN (1, 6)", lambda x: x[(x["b"] == 1) | (x["b"] == 6)], - [[("b", "<=", 1), ("b", ">=", 1)], [("b", "<=", 6), ("b", ">=", 6)]], + [[("b", "==", 1)], [("b", "==", 6)]], + ), + pytest.param( + "SELECT * FROM parquet_ddf WHERE b IN (1, 3, 5, 6)", + lambda x: x[(x["b"] == 1) | (x["b"] == 3) | (x["b"] == 5) | (x["b"] == 6)], + [[("b", "==", 1)], [("b", "==", 3)], [("b", "==", 5)], [("b", "==", 6)]], marks=pytest.mark.xfail( reason="WIP https://github.com/dask-contrib/dask-sql/issues/607" ), @@ -170,7 +180,12 @@ def test_filter_year(c): ( "SELECT a FROM parquet_ddf WHERE (b > 5 AND b < 10) OR a = 1", lambda x: x[((x["b"] > 5) & (x["b"] < 10)) | (x["a"] == 1)][["a"]], - [[("a", "==", 1)], [("b", "<", 10), ("b", ">", 5)]], + [ + [("a", "==", 1), ("b", "<", 10)], + [("a", "==", 1), ("b", ">", 5)], + [("b", ">", 5), ("b", "<", 10)], + [("a", "==", 1)], + ], ), ( # Original filters NOT in disjunctive normal form @@ -206,6 +221,7 @@ def test_predicate_pushdown(c, parquet_ddf, query, df_func, filters): if expect_filters: got_filters = frozenset(frozenset(v) for v in got_filters) expect_filters = frozenset(frozenset(v) for v in filters) + assert got_filters == expect_filters # 
Check computed result is correct diff --git a/tests/unit/test_queries.py b/tests/unit/test_queries.py index 15c044d95..becb378a5 100644 --- a/tests/unit/test_queries.py +++ b/tests/unit/test_queries.py @@ -11,7 +11,6 @@ 14, 16, 18, - 21, 22, 23, 24, @@ -20,7 +19,6 @@ 35, 36, 39, - 40, 41, 44, 45,
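
Reviewer note (illustration only): on the Python side, `_apply_filters` converts every filter expression attached to the `TableScan` into a boolean Dask series and folds them together with a logical AND via `reduce(operator.and_, ...)`. The sketch below mirrors just that conjunction step, with plain boolean columns standing in for the output of `RexConverter`/`filter_or_scalar`; the data is invented.

```python
import operator
from functools import reduce

import dask.dataframe as dd
import pandas as pd

# Toy frame standing in for the scanned table
pdf = pd.DataFrame({"a": [1, 2, 3, 4], "b": [4, 6, 8, 12]})
ddf = dd.from_pandas(pdf, npartitions=2)

# Each entry mimics one pushed-down filter expression
# (roughly what RexConverter.convert would produce)
conditions = [ddf["b"] > 5, ddf["b"] < 10]

# All partial filters are applied in conjunction (&), as in _apply_filters
mask = reduce(operator.and_, conditions)
print(ddf[mask].compute())  # rows where 5 < b < 10
```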
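Reviewer note (illustration only): the updated `test_predicate_pushdown` expectations are disjunctive-normal-form filter lists, e.g. `WHERE b IN (1, 6)` is now expected to reach the parquet layer as `[[("b", "==", 1)], [("b", "==", 6)]]`. The snippet below shows what such a list means when handed to `dask.dataframe.read_parquet` directly; the path is made up, and the exact row-level behaviour depends on the parquet engine (row groups that cannot match are pruned at read time).

```python
import dask.dataframe as dd
import pandas as pd

path = "/tmp/parquet_ddf_example"  # hypothetical path, for illustration

# Write a small parquet dataset to read back with filters
pdf = pd.DataFrame({"a": range(10), "b": range(10)})
dd.from_pandas(pdf, npartitions=2).to_parquet(path)

# Outer list = OR of conjunctions, inner lists = AND-ed predicates (DNF),
# matching the filters the test now expects for `WHERE b IN (1, 6)`
filters = [[("b", "==", 1)], [("b", "==", 6)]]

ddf = dd.read_parquet(path, filters=filters)
print(ddf.compute())
```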
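Reviewer note (illustration only): nothing changes in how a query is written; with `Exact` pushdown the `WHERE` clause is carried on the `TableScan` (exposed through `getFilters`) rather than a separate `Filter` node, and `_apply_filters` applies it while reading. A hedged end-to-end sketch, with invented table contents:

```python
import dask.dataframe as dd
import pandas as pd
from dask_sql import Context

c = Context()
ddf = dd.from_pandas(pd.DataFrame({"a": [1, 1, 2], "b": [3, 7, 9]}), npartitions=1)
c.create_table("parquet_ddf", ddf)

# The predicate ends up attached to the TableScan as filter expressions
result = c.sql("SELECT * FROM parquet_ddf WHERE (b > 5 AND b < 10) OR a = 1")
print(result.compute())
```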