Enable DataFusion CBO and introduce DaskSqlOptimizer (#558)
* Condition for BinaryExpr, filter, input_ref, rexcall, and rexliteral

* Updates for test_filter

* more of test_filter.py working with the exception of some date pytests

* Add workflow to keep datafusion dev branch up to date (#440)

* Include setuptools-rust in conda build recipe, in host and run

* Remove PyArrow dependency

* rebase with datafusion-sql-planner

* refactor changes that were inadvertent during rebase

* timestamp with local time zone

* Bump DataFusion version (#494)

* bump DataFusion version

* remove unnecessary downcasts and use separate structs for TableSource and TableProvider

* Include RelDataType work

* Include RelDataType work

* Introduced SqlTypeName Enum in Rust and mappings for Python

* impl PyExpr.getIndex()

* add getRowType() for logical.rs

* Introduce DaskTypeMap for storing correlating SqlTypeName and DataTypes

* use str values instead of Rust Enums, since Python is unable to hash the Rust Enums when used in a dict

* linter changes, why did that work on my local pre-commit??

* linter changes, why did that work on my local pre-commit??

* Convert final strs to SqlTypeName Enum

* removed a few print statements

* commit to share with colleague

* updates

* checkpoint

* Temporarily disable conda run_test.py script since it uses features not yet implemented

* formatting after upstream merge

* expose fromString method for SqlTypeName to use Enums instead of strings for type checking

* expanded SqlTypeName from_string() support

* accept INT as INTEGER

* tests update

* checkpoint

* checkpoint

* Refactor PyExpr by removing From trait, and using recursion to expand expression list for rex calls

* skip test that uses create statement for gpuci

* Basic DataFusion Select Functionality (#489)

* Condition for BinaryExpr, filter, input_ref, rexcall, and rexliteral

* Updates for test_filter

* more of test_filter.py working with the exception of some date pytests

* Add workflow to keep datafusion dev branch up to date (#440)

* Include setuptools-rust in conda build recipe, in host and run

* Remove PyArrow dependency

* rebase with datafusion-sql-planner

* refactor changes that were inadvertent during rebase

* timestamp with local time zone

* Include RelDataType work

* Include RelDataType work

* Introduced SqlTypeName Enum in Rust and mappings for Python

* impl PyExpr.getIndex()

* add getRowType() for logical.rs

* Introduce DaskTypeMap for storing correlating SqlTypeName and DataTypes

* use str values instead of Rust Enums, since Python is unable to hash the Rust Enums when used in a dict

* linter changes, why did that work on my local pre-commit??

* linter changes, why did that work on my local pre-commit??

* Convert final strs to SqlTypeName Enum

* removed a few print statements

* Temporarily disable conda run_test.py script since it uses features not yet implemented

* expose fromString method for SqlTypeName to use Enums instead of strings for type checking

* expanded SqlTypeName from_string() support

* accept INT as INTEGER

* Remove print statements

* Default to UTC if tz is None

* Delegate timezone handling to the arrow library

* Updates from review

Co-authored-by: Charles Blackmon-Luca <[email protected]>

* updates for expression

* uncommented pytests

* uncommented pytests

* code cleanup for review

* code cleanup for review

* Enabled more pytest that work now

* Enabled more pytest that work now

* Output Expression as String when BinaryExpr does not contain a named alias

* Output Expression as String when BinaryExpr does not contain a named alias

* Disable 2 pytests that are causing gpuCI issues. They will be addressed in a follow-up PR

* Handle Between operation for case-when

* adjust timestamp casting

* Refactor projection _column_name() logic to the _column_name logic in expression.rs

* removed println! statements

* introduce join getCondition() logic for retrieving the combining Rex logic for joining

* Updates from review

* Add Offset and point to repo with offset in datafusion

* Introduce offset

* limit updates

* commit before upstream merge

* Code formatting

* update Cargo.toml to use Arrow-DataFusion version with LIMIT logic

* Bump DataFusion version to get changes around variant_name()

* Use map partitions for determining the offset

* Merge with upstream

* Rename underlying DataContainer's DataFrame instance to match the column container names

* Adjust ColumnContainer mapping after join.py logic to ensure the backend mapping is reset

* Add enumerate to column_{i} generation string to ensure columns exist in both dataframes

* Adjust join schema logic to perform merge instead of join on rust side to avoid name collisions

* Handle DataFusion COUNT(UInt8(1)) as COUNT(*)

* commit before merge

* Update function for gathering index of an expression

* Update for review check

* Adjust RelDataType to retrieve fully qualified column names

* Adjust base.py to get fully qualified column name

* Enable passing pytests in test_join.py

* Adjust keys provided by getting backend column mapping name

* Adjust output_col to not use the backend_column name for special reserved exprs

* uncomment cross join pytest which works now

* Uncomment passing pytests in test_select.py

* Review updates

* Add back complex join case condition, not just cross join but 'complex' joins

* Enable DataFusion CBO logic

* Disable EliminateFilter optimization rule

* updates

* Disable tests that hit CBO-generated plan edge cases involving yet-to-be-implemented logic

* [REVIEW] - Modify sql.skip_optimize to use dask_config.get and remove unused method parameter

* [REVIEW] - change name of configuration from skip_optimize to optimize

* [REVIEW] - Add OptimizeException catch and raise statements back

* Found an issue where backend column names that result from a single aggregate column (COUNT(*), for example) need to use the first agg df column, since the generated names are not valid

* Remove SQL from OptimizationException

* Skip tests where CBO plan reorganization exposes not-yet-implemented features

Co-authored-by: Charles Blackmon-Luca <[email protected]>
Co-authored-by: Andy Grove <[email protected]>
3 people authored Jun 7, 2022
1 parent 2a2b5d1 commit d233b9d
Showing 15 changed files with 176 additions and 25 deletions.
4 changes: 4 additions & 0 deletions dask_planner/src/lib.rs
@@ -33,6 +33,10 @@ fn rust(py: Python, m: &PyModule) -> PyResult<()> {
"DFParsingException",
py.get_type::<sql::exceptions::ParsingException>(),
)?;
m.add(
"DFOptimizationException",
py.get_type::<sql::exceptions::OptimizationException>(),
)?;

Ok(())
}
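
With this registration in place, the new exception type is importable from the compiled extension module alongside the existing parsing exception. A minimal Python-side sketch (the import path follows the context.py changes later in this diff):

# Sketch only: both exception types are exposed by the dask_planner Rust extension
from dask_planner.rust import DFOptimizationException, DFParsingException
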
54 changes: 51 additions & 3 deletions dask_planner/src/sql.rs
@@ -2,27 +2,27 @@ pub mod column;
pub mod exceptions;
pub mod function;
pub mod logical;
pub mod optimizer;
pub mod schema;
pub mod statement;
pub mod table;
pub mod types;

use crate::sql::exceptions::ParsingException;
use crate::sql::exceptions::{OptimizationException, ParsingException};

use datafusion::arrow::datatypes::{Field, Schema};
use datafusion::catalog::{ResolvedTableReference, TableReference};
use datafusion::datasource::TableProvider;
use datafusion::error::DataFusionError;
use datafusion::logical_expr::{
AggregateUDF, ScalarFunctionImplementation, ScalarUDF, TableSource,
};
use datafusion::logical_plan::{LogicalPlan, PlanVisitor};
use datafusion::sql::parser::DFParser;
use datafusion::sql::planner::{ContextProvider, SqlToRel};

use std::collections::HashMap;
use std::sync::Arc;

use crate::sql::table::DaskTableSource;
use pyo3::prelude::*;

/// DaskSQLContext is main interface used for interacting with DataFusion to
@@ -177,4 +177,52 @@ impl DaskSQLContext {
})
.map_err(|e| PyErr::new::<ParsingException, _>(format!("{}", e)))
}

/// Accepts an existing relational plan, `LogicalPlan`, and optimizes it
/// by applying a set of `optimizer` trait implementations against the
/// `LogicalPlan`
pub fn optimize_relational_algebra(
&self,
existing_plan: logical::PyLogicalPlan,
) -> PyResult<logical::PyLogicalPlan> {
// Certain queries cannot be optimized (e.g. `EXPLAIN SELECT * FROM test`); simply return those plans as-is
let mut visitor = OptimizablePlanVisitor {};

match existing_plan.original_plan.accept(&mut visitor) {
Ok(valid) => {
if valid {
optimizer::DaskSqlOptimizer::new()
.run_optimizations(existing_plan.original_plan)
.map(|k| logical::PyLogicalPlan {
original_plan: k,
current_node: None,
})
.map_err(|e| PyErr::new::<OptimizationException, _>(format!("{}", e)))
} else {
// This LogicalPlan does not support Optimization. Return original
Ok(existing_plan)
}
}
Err(e) => Err(PyErr::new::<OptimizationException, _>(format!("{}", e))),
}
}
}

/// Visits each plan node to determine whether the plan is valid for optimization or not
pub struct OptimizablePlanVisitor;

impl PlanVisitor for OptimizablePlanVisitor {
type Error = DataFusionError;

fn pre_visit(&mut self, plan: &LogicalPlan) -> std::result::Result<bool, DataFusionError> {
// If the plan contains an unsupported Node type we flag the plan as un-optimizable here
match plan {
LogicalPlan::Explain(..) => Ok(false),
_ => Ok(true),
}
}

fn post_visit(&mut self, _plan: &LogicalPlan) -> std::result::Result<bool, DataFusionError> {
Ok(true)
}
}
4 changes: 4 additions & 0 deletions dask_planner/src/sql/exceptions.rs
@@ -2,8 +2,12 @@ use datafusion::error::DataFusionError;
use pyo3::{create_exception, PyErr};
use std::fmt::Debug;

// Identifies exceptions that occur while attempting to generate a `LogicalPlan` from a SQL string
create_exception!(rust, ParsingException, pyo3::exceptions::PyException);

// Identifies exceptions that occur during attempts to optimize an existing `LogicalPlan`
create_exception!(rust, OptimizationException, pyo3::exceptions::PyException);

pub fn py_type_err(e: impl Debug) -> PyErr {
PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!("{:?}", e))
}
4 changes: 2 additions & 2 deletions dask_planner/src/sql/logical/join.rs
@@ -2,9 +2,9 @@ use crate::expression::PyExpr;
use crate::sql::column;

use datafusion::logical_expr::{
and, binary_expr,
and,
logical_plan::{Join, JoinType, LogicalPlan},
Expr, Operator,
Expr,
};

use crate::sql::exceptions::py_type_err;
54 changes: 54 additions & 0 deletions dask_planner/src/sql/optimizer.rs
@@ -0,0 +1,54 @@
use datafusion::error::DataFusionError;
use datafusion::logical_expr::LogicalPlan;
use datafusion::optimizer::eliminate_limit::EliminateLimit;
use datafusion::optimizer::filter_push_down::FilterPushDown;
use datafusion::optimizer::limit_push_down::LimitPushDown;
use datafusion::optimizer::optimizer::OptimizerRule;
use datafusion::optimizer::OptimizerConfig;

use datafusion::optimizer::common_subexpr_eliminate::CommonSubexprEliminate;
use datafusion::optimizer::projection_push_down::ProjectionPushDown;
use datafusion::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy;
use datafusion::optimizer::subquery_filter_to_join::SubqueryFilterToJoin;

/// Houses the optimization logic for Dask-SQL. This struct controls which optimizations are applied
/// and their ordering with regard to their impact on the underlying `LogicalPlan` instance
pub struct DaskSqlOptimizer {
optimizations: Vec<Box<dyn OptimizerRule + Send + Sync>>,
}

impl DaskSqlOptimizer {
/// Creates a new instance of the DaskSqlOptimizer with all of the desired DataFusion
/// optimizer rules, as well as any custom `OptimizerRule` trait impls that might be desired.
pub fn new() -> Self {
let mut rules: Vec<Box<dyn OptimizerRule + Send + Sync>> = Vec::new();
rules.push(Box::new(CommonSubexprEliminate::new()));
rules.push(Box::new(EliminateLimit::new()));
rules.push(Box::new(FilterPushDown::new()));
rules.push(Box::new(LimitPushDown::new()));
rules.push(Box::new(ProjectionPushDown::new()));
rules.push(Box::new(SingleDistinctToGroupBy::new()));
rules.push(Box::new(SubqueryFilterToJoin::new()));
Self {
optimizations: rules,
}
}

/// Iterates through the configured `OptimizerRule`(s) to transform the input `LogicalPlan`
/// to its final optimized form
pub(crate) fn run_optimizations(
&self,
plan: LogicalPlan,
) -> Result<LogicalPlan, DataFusionError> {
let mut resulting_plan: LogicalPlan = plan;
for optimization in &self.optimizations {
match optimization.optimize(&resulting_plan, &OptimizerConfig::new()) {
Ok(optimized_plan) => resulting_plan = optimized_plan,
Err(e) => {
return Err(e);
}
}
}
Ok(resulting_plan)
}
}
4 changes: 1 addition & 3 deletions dask_planner/src/sql/table.rs
@@ -7,9 +7,7 @@ use crate::sql::types::SqlTypeName;
use async_trait::async_trait;

use datafusion::arrow::datatypes::{DataType, Field, SchemaRef};
use datafusion::datasource::{TableProvider, TableType};
use datafusion::error::DataFusionError;
use datafusion::logical_expr::{Expr, LogicalPlan, TableSource};
use datafusion::logical_expr::{LogicalPlan, TableSource};

use pyo3::prelude::*;

32 changes: 22 additions & 10 deletions dask_sql/context.py
@@ -11,7 +11,13 @@
from dask.base import optimize
from dask.distributed import Client

from dask_planner.rust import DaskSchema, DaskSQLContext, DaskTable, DFParsingException
from dask_planner.rust import (
DaskSchema,
DaskSQLContext,
DaskTable,
DFOptimizationException,
DFParsingException,
)

try:
import dask_cuda # noqa: F401
@@ -31,7 +31,7 @@
from dask_sql.mappings import python_to_sql_type
from dask_sql.physical.rel import RelConverter, custom, logical
from dask_sql.physical.rex import RexConverter, core
from dask_sql.utils import ParsingException
from dask_sql.utils import OptimizationException, ParsingException

if TYPE_CHECKING:
from dask_planner.rust import Expression
@@ -829,17 +835,23 @@ def _get_ral(self, sql):
except DFParsingException as pe:
raise ParsingException(sql, str(pe)) from None

rel = nonOptimizedRel
logger.debug(f"_get_ral -> nonOptimizedRelNode: {nonOptimizedRel}")
# Optimization might remove some alias projects. Make sure to keep them here.
select_names = [field for field in rel.getRowType().getFieldList()]
# Optimize the `LogicalPlan` or skip if configured
if dask_config.get("sql.optimize"):
try:
rel = self.context.optimize_relational_algebra(nonOptimizedRel)
except DFOptimizationException as oe:
rel = nonOptimizedRel
raise OptimizationException(str(oe)) from None
else:
rel = nonOptimizedRel

# TODO: For POC we are not optimizing the relational algebra - Jeremy Dyer
# rel = generator.getOptimizedRelationalAlgebra(nonOptimizedRelNode)
# rel_string = str(generator.getRelationalAlgebraString(rel))
rel_string = rel.explain_original()

logger.debug(f"_get_ral -> LogicalPlan: {rel}")
logger.debug(f"Extracted relational algebra:\n {rel_string}")

# Optimization might remove some alias projects. Make sure to keep them here.
select_names = [field for field in rel.getRowType().getFieldList()]

return rel, select_names, rel_string

def _get_tables_from_stack(self):
15 changes: 8 additions & 7 deletions dask_sql/physical/rel/logical/aggregate.py
@@ -186,9 +186,13 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai

# Fix the column names and the order of them, as this was messed with during the aggregations
df_agg.columns = df_agg.columns.get_level_values(-1)
backend_output_column_order = [
cc.get_backend_by_frontend_name(oc) for oc in output_column_order
]

if len(output_column_order) == 1 and output_column_order[0] == "UInt8(1)":
backend_output_column_order = [df_agg.columns[0]]
else:
backend_output_column_order = [
cc.get_backend_by_frontend_name(oc) for oc in output_column_order
]
cc = ColumnContainer(df_agg.columns).limit_to(backend_output_column_order)

cc = self.fix_column_to_row_type(cc, rel.getRowType())
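
A hedged illustration of the special case added above: DataFusion labels the COUNT(*) output column "UInt8(1)", which has no valid backend mapping, so the single aggregated column is used directly. The names below mirror the variables in convert(); the ColumnContainer lookup is omitted and the data is made up.

import pandas as pd

# Assumed shape of the aggregation result for a bare COUNT(*) query
df_agg = pd.DataFrame({"COUNT(UInt8(1))": [42]})
output_column_order = ["UInt8(1)"]

if len(output_column_order) == 1 and output_column_order[0] == "UInt8(1)":
    # the generated name is not a valid frontend name, so fall back to the
    # first (and only) aggregated column
    backend_output_column_order = [df_agg.columns[0]]
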
@@ -425,7 +429,7 @@ def _perform_aggregation(
if additional_column_name is None:
additional_column_name = new_temporary_column(dc.df)

# perform groupby operation; if we are using custom aggregations, we must handle
# perform groupby operation; if we are using custom aggregations, we must handle
# null values manually (this is slow)
if fast_groupby:
group_columns = [
Expand All @@ -448,11 +452,8 @@ def _perform_aggregation(

for col in agg_result.columns:
logger.debug(col)
logger.debug(f"agg_result: {agg_result.head()}")

# fix the column names to a single level
agg_result.columns = agg_result.columns.get_level_values(-1)

logger.debug(f"agg_result after: {agg_result.head()}")

return agg_result
5 changes: 5 additions & 0 deletions dask_sql/sql-schema.yaml
@@ -31,3 +31,8 @@ properties:
type: boolean
description: |
Whether to try pushing down filter predicates into IO (when possible).
optimize:
type: boolean
description: |
Whether the first generated logical plan should be further optimized or used as is.
2 changes: 2 additions & 0 deletions dask_sql/sql.yaml
@@ -7,3 +7,5 @@ sql:
case_sensitive: True

predicate_pushdown: True

optimize: True
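
Because the new flag is read through dask's configuration system (dask_config.get("sql.optimize") in context.py above), it can be toggled at runtime. A usage sketch, assuming c is an existing dask_sql.Context and my_table is a registered table:

import dask

# Skip the DataFusion optimizer for a single query; the first generated
# logical plan is then used as-is.
with dask.config.set({"sql.optimize": False}):
    c.sql("SELECT * FROM my_table")
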
13 changes: 13 additions & 0 deletions dask_sql/utils.py
@@ -94,6 +94,19 @@ def __init__(self, sql, validation_exception_string):
super().__init__(validation_exception_string.strip())


class OptimizationException(Exception):
"""
Helper class for formatting exceptions that occur while trying to
optimize a logical plan
"""

def __init__(self, exception_string):
"""
Create a new exception from the optimization error reported by DataFusion
"""
super().__init__(exception_string.strip())


class LoggableDataFrame:
"""Small helper class to print resulting dataframes or series in logging messages"""

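On the Python side, OptimizationException mirrors the existing ParsingException. A hedged sketch of how it surfaces to callers (the Context instance c and the table name are assumed):

from dask_sql.utils import OptimizationException

try:
    c.sql("SELECT * FROM my_table")
except OptimizationException as oe:
    # raised by Context._get_ral() when DataFusion's optimizer reports an error
    print(f"Plan optimization failed: {oe}")
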
3 changes: 3 additions & 0 deletions tests/integration/test_compatibility.py
@@ -156,6 +156,9 @@ def test_order_by_no_limit():
)


@pytest.mark.skip(
reason="WIP DataFusion - https://github.com/dask-contrib/dask-sql/issues/530"
)
def test_order_by_limit():
a = make_rand_df(100, a=(int, 50), b=(str, 50), c=float)
eq_sqlite(
3 changes: 3 additions & 0 deletions tests/integration/test_join.py
@@ -121,6 +121,9 @@ def test_join_cross(c, user_table_1, department_table):
assert_eq(return_df, expected_df, check_index=False)


@pytest.mark.skip(
reason="WIP DataFusion - Enabling CBO generates yet to be implemented edge case"
)
def test_join_complex(c):
return_df = c.sql(
"""
3 changes: 3 additions & 0 deletions tests/integration/test_rex.py
@@ -8,6 +8,9 @@
from tests.utils import assert_eq


@pytest.mark.skip(
reason="WIP DataFusion - Enabling CBO generates yet to be implemented edge case"
)
def test_case(c, df):
result_df = c.sql(
"""
1 change: 1 addition & 0 deletions tests/integration/test_sqlite.py
@@ -101,6 +101,7 @@ def test_limit(assert_query_gives_same_result):
)


@pytest.mark.skip(reason="WIP DataFusion")
def test_groupby(assert_query_gives_same_result):
assert_query_gives_same_result(
"""
