From 1b35a15cf81b800e957e22828b71cfd99cf42b51 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Wed, 2 Oct 2024 12:32:08 +0200 Subject: [PATCH] Move analyzer out of optimizer DataFusion is a SQL query engine and also a reusable library for building query engines. The optimizer part is generic, but analyzer's role is to support the DataFusion SQL frontend. Separate the concerns, so that the optimizer crate is truly reusable. --- datafusion-cli/Cargo.lock | 1 + datafusion-examples/examples/analyzer_rule.rs | 2 +- datafusion-examples/examples/sql_frontend.rs | 5 +- datafusion/core/src/datasource/view.rs | 7 +- datafusion/core/src/execution/context/mod.rs | 3 +- .../core/src/execution/session_state.rs | 6 +- datafusion/core/tests/optimizer/mod.rs | 2 +- .../tests/user_defined/user_defined_plan.rs | 2 +- datafusion/expr/src/expr.rs | 19 +- datafusion/optimizer/src/decorrelate.rs | 3 +- .../optimizer/src/eliminate_nested_union.rs | 4 +- datafusion/optimizer/src/lib.rs | 2 - datafusion/optimizer/src/test/mod.rs | 41 +-- datafusion/optimizer/src/utils.rs | 19 +- .../optimizer/tests/optimizer_integration.rs | 2 +- datafusion/sql/Cargo.toml | 1 + .../src/analyzer/count_wildcard_rule.rs | 7 +- .../src/analyzer/expand_wildcard_rule.rs | 7 +- .../src/analyzer/function_rewrite.rs | 3 +- .../src/analyzer/inline_table_scan.rs | 2 +- .../{optimizer => sql}/src/analyzer/mod.rs | 9 +- .../src/analyzer/resolve_grouping_function.rs | 247 ++++++++++++++++++ .../src/analyzer/subquery.rs | 2 +- .../src/analyzer/type_coercion.rs | 3 +- datafusion/sql/src/lib.rs | 3 + datafusion/sql/src/test.rs | 80 ++++++ 26 files changed, 387 insertions(+), 95 deletions(-) rename datafusion/{optimizer => sql}/src/analyzer/count_wildcard_rule.rs (98%) rename datafusion/{optimizer => sql}/src/analyzer/expand_wildcard_rule.rs (99%) rename datafusion/{optimizer => sql}/src/analyzer/function_rewrite.rs (97%) rename datafusion/{optimizer => sql}/src/analyzer/inline_table_scan.rs (100%) rename 
datafusion/{optimizer => sql}/src/analyzer/mod.rs (96%) create mode 100644 datafusion/sql/src/analyzer/resolve_grouping_function.rs rename datafusion/{optimizer => sql}/src/analyzer/subquery.rs (99%) rename datafusion/{optimizer => sql}/src/analyzer/type_coercion.rs (99%) create mode 100644 datafusion/sql/src/test.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index ca67e3e4f531..f544c04f51c1 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1574,6 +1574,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "indexmap", + "itertools", "log", "regex", "sqlparser", diff --git a/datafusion-examples/examples/analyzer_rule.rs b/datafusion-examples/examples/analyzer_rule.rs index bd067be97b8b..bd835ac95c65 100644 --- a/datafusion-examples/examples/analyzer_rule.rs +++ b/datafusion-examples/examples/analyzer_rule.rs @@ -21,7 +21,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::Result; use datafusion_expr::{col, lit, Expr, LogicalPlan, LogicalPlanBuilder}; -use datafusion_optimizer::analyzer::AnalyzerRule; +use datafusion_sql::analyzer::AnalyzerRule; use std::sync::{Arc, Mutex}; /// This example demonstrates how to add your own [`AnalyzerRule`] to diff --git a/datafusion-examples/examples/sql_frontend.rs b/datafusion-examples/examples/sql_frontend.rs index 839ee95eb181..2f8d56a66acd 100644 --- a/datafusion-examples/examples/sql_frontend.rs +++ b/datafusion-examples/examples/sql_frontend.rs @@ -22,9 +22,8 @@ use datafusion_expr::{ AggregateUDF, Expr, LogicalPlan, ScalarUDF, TableProviderFilterPushDown, TableSource, WindowUDF, }; -use datafusion_optimizer::{ - Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule, -}; +use datafusion_optimizer::{Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule}; +use datafusion_sql::analyzer::{Analyzer, AnalyzerRule}; use datafusion_sql::planner::{ContextProvider, 
SqlToRel}; use datafusion_sql::sqlparser::dialect::PostgreSqlDialect; use datafusion_sql::sqlparser::parser::Parser; diff --git a/datafusion/core/src/datasource/view.rs b/datafusion/core/src/datasource/view.rs index 1ffe54e4b06c..6b139d01dcbf 100644 --- a/datafusion/core/src/datasource/view.rs +++ b/datafusion/core/src/datasource/view.rs @@ -19,6 +19,7 @@ use std::{any::Any, borrow::Cow, sync::Arc}; +use crate::datasource::{TableProvider, TableType}; use crate::{ error::Result, logical_expr::{Expr, LogicalPlan}, @@ -30,10 +31,8 @@ use datafusion_catalog::Session; use datafusion_common::config::ConfigOptions; use datafusion_common::Column; use datafusion_expr::{LogicalPlanBuilder, TableProviderFilterPushDown}; -use datafusion_optimizer::analyzer::expand_wildcard_rule::ExpandWildcardRule; -use datafusion_optimizer::Analyzer; - -use crate::datasource::{TableProvider, TableType}; +use datafusion_sql::analyzer::expand_wildcard_rule::ExpandWildcardRule; +use datafusion_sql::analyzer::Analyzer; /// An implementation of `TableProvider` that uses another logical plan. 
#[derive(Debug)] diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 606759aae5ee..4ecf44fc853f 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -77,7 +77,8 @@ use datafusion_catalog::{DynamicFileCatalog, SessionStore, UrlTableFactory}; pub use datafusion_execution::config::SessionConfig; pub use datafusion_execution::TaskContext; pub use datafusion_expr::execution_props::ExecutionProps; -use datafusion_optimizer::{AnalyzerRule, OptimizerRule}; +use datafusion_optimizer::OptimizerRule; +use datafusion_sql::analyzer::AnalyzerRule; use object_store::ObjectStore; use parking_lot::RwLock; use url::Url; diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 516a0e700a7b..2631be9535ed 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -56,14 +56,12 @@ use datafusion_expr::{ AggregateUDF, Explain, Expr, ExprSchemable, LogicalPlan, ScalarUDF, TableSource, WindowUDF, }; -use datafusion_optimizer::analyzer::type_coercion::TypeCoercionRewriter; -use datafusion_optimizer::{ - Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerRule, -}; +use datafusion_optimizer::{Optimizer, OptimizerConfig, OptimizerRule}; use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::ExecutionPlan; +use datafusion_sql::analyzer::{Analyzer, AnalyzerRule}; use datafusion_sql::parser::{DFParser, Statement}; use datafusion_sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}; use itertools::Itertools; diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index f17d13a42060..08860bdcd05e 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ 
b/datafusion/core/tests/optimizer/mod.rs @@ -33,7 +33,6 @@ use datafusion_expr::{ ScalarUDF, TableSource, WindowUDF, }; use datafusion_functions::core::expr_ext::FieldAccessor; -use datafusion_optimizer::analyzer::Analyzer; use datafusion_optimizer::optimizer::Optimizer; use datafusion_optimizer::simplify_expressions::GuaranteeRewriter; use datafusion_optimizer::{OptimizerConfig, OptimizerContext}; @@ -45,6 +44,7 @@ use datafusion_sql::TableReference; use chrono::DateTime; use datafusion_functions::datetime; +use datafusion_sql::analyzer::Analyzer; #[cfg(test)] #[ctor::ctor] diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 6c4e3c66e397..8b100073a841 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -100,7 +100,7 @@ use datafusion_common::ScalarValue; use datafusion_expr::tree_node::replace_sort_expression; use datafusion_expr::{FetchType, Projection, SortExpr}; use datafusion_optimizer::optimizer::ApplyOrder; -use datafusion_optimizer::AnalyzerRule; +use datafusion_sql::analyzer::AnalyzerRule; /// Execute the specified sql and return the resulting record batches /// pretty printed as a String. diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 4d73c2a04486..8e72b71c1203 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -17,7 +17,7 @@ //! 
Logical Expressions: [`Expr`] -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::fmt::{self, Display, Formatter, Write}; use std::hash::{Hash, Hasher}; use std::mem; @@ -1830,6 +1830,23 @@ fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Resu Ok(()) } +pub fn collect_subquery_cols( + exprs: &[Expr], + subquery_schema: &DFSchema, +) -> Result> { + exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| { + let mut using_cols: Vec = vec![]; + for col in expr.column_refs().into_iter() { + if subquery_schema.has_column(col) { + using_cols.push(col.clone()); + } + } + + cols.extend(using_cols); + Result::<_>::Ok(cols) + }) +} + #[macro_export] macro_rules! expr_vec_fmt { ( $ARRAY:expr ) => {{ diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 6aa59b77f7f9..277714eff849 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -22,13 +22,12 @@ use std::ops::Deref; use std::sync::Arc; use crate::simplify_expressions::ExprSimplifier; -use crate::utils::collect_subquery_cols; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; use datafusion_common::{plan_err, Column, DFSchemaRef, Result, ScalarValue}; -use datafusion_expr::expr::Alias; +use datafusion_expr::expr::{collect_subquery_cols, Alias}; use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; use datafusion_expr::{ diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 94da08243d78..7aaeb6fa34c4 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -114,12 +114,12 @@ fn extract_plan_from_distinct(plan: Arc) -> Arc { #[cfg(test)] mod tests { use super::*; - use 
crate::analyzer::type_coercion::TypeCoercion; - use crate::analyzer::Analyzer; use crate::test::*; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_expr::{col, logical_plan::table_scan}; + use datafusion_sql::analyzer::type_coercion::TypeCoercion; + use datafusion_sql::analyzer::Analyzer; fn schema() -> Schema { Schema::new(vec![ diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index f31083831125..88640bc7b404 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -30,7 +30,6 @@ //! //! [`LogicalPlan`]: datafusion_expr::LogicalPlan //! [`TypeCoercion`]: analyzer::type_coercion::TypeCoercion -pub mod analyzer; pub mod common_subexpr_eliminate; pub mod decorrelate; pub mod decorrelate_predicate_subquery; @@ -60,7 +59,6 @@ pub mod utils; #[cfg(test)] pub mod test; -pub use analyzer::{Analyzer, AnalyzerRule}; pub use optimizer::{Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule}; #[allow(deprecated)] pub use utils::optimize_children; diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 94d07a0791b3..f1503402f8c0 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. 
-use crate::analyzer::{Analyzer, AnalyzerRule}; use crate::optimizer::Optimizer; use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_common::{assert_contains, Result}; use datafusion_expr::{col, logical_plan::table_scan, LogicalPlan, LogicalPlanBuilder}; +use datafusion_sql::analyzer::{Analyzer, AnalyzerRule}; use std::sync::Arc; pub mod user_defined; @@ -108,45 +108,6 @@ pub fn get_tpch_table_schema(table: &str) -> Schema { } } -pub fn assert_analyzed_plan_eq( - rule: Arc, - plan: LogicalPlan, - expected: &str, -) -> Result<()> { - let options = ConfigOptions::default(); - assert_analyzed_plan_with_config_eq(options, rule, plan, expected)?; - - Ok(()) -} - -pub fn assert_analyzed_plan_with_config_eq( - options: ConfigOptions, - rule: Arc, - plan: LogicalPlan, - expected: &str, -) -> Result<()> { - let analyzed_plan = - Analyzer::with_rules(vec![rule]).execute_and_check(plan, &options, |_, _| {})?; - let formatted_plan = format!("{analyzed_plan}"); - assert_eq!(formatted_plan, expected); - - Ok(()) -} - -pub fn assert_analyzed_plan_eq_display_indent( - rule: Arc, - plan: LogicalPlan, - expected: &str, -) -> Result<()> { - let options = ConfigOptions::default(); - let analyzed_plan = - Analyzer::with_rules(vec![rule]).execute_and_check(plan, &options, |_, _| {})?; - let formatted_plan = analyzed_plan.display_indent_schema().to_string(); - assert_eq!(formatted_plan, expected); - - Ok(()) -} - pub fn assert_analyzer_check_err( rules: Vec>, plan: LogicalPlan, diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 6972c16c0ddf..1d886239b7b6 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -21,7 +21,7 @@ use std::collections::{BTreeSet, HashMap, HashSet}; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{Column, DFSchema, Result}; +use datafusion_common::{Column, Result}; use 
datafusion_expr::expr_rewriter::replace_col; use datafusion_expr::{logical_plan::LogicalPlan, Expr}; @@ -80,23 +80,6 @@ pub(crate) fn has_all_column_refs(expr: &Expr, schema_cols: &HashSet) -> == column_refs.len() } -pub(crate) fn collect_subquery_cols( - exprs: &[Expr], - subquery_schema: &DFSchema, -) -> Result> { - exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| { - let mut using_cols: Vec = vec![]; - for col in expr.column_refs().into_iter() { - if subquery_schema.has_column(col) { - using_cols.push(col.clone()); - } - } - - cols.extend(using_cols); - Result::<_>::Ok(cols) - }) -} - pub(crate) fn replace_qualified_name( expr: Expr, cols: &BTreeSet, diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 236167985790..f8a8782433be 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -27,9 +27,9 @@ use datafusion_expr::test::function_stub::sum_udaf; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::count::count_udaf; -use datafusion_optimizer::analyzer::Analyzer; use datafusion_optimizer::optimizer::Optimizer; use datafusion_optimizer::{OptimizerConfig, OptimizerContext, OptimizerRule}; +use datafusion_sql::analyzer::Analyzer; use datafusion_sql::planner::{ContextProvider, SqlToRel}; use datafusion_sql::sqlparser::ast::Statement; use datafusion_sql::sqlparser::dialect::GenericDialect; diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 90be576a884e..99b8c879ff41 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -47,6 +47,7 @@ arrow-schema = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } indexmap = { workspace = true } +itertools = { workspace = true } log = { workspace = 
true } regex = { workspace = true } sqlparser = { workspace = true } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/sql/src/analyzer/count_wildcard_rule.rs similarity index 98% rename from datafusion/optimizer/src/analyzer/count_wildcard_rule.rs rename to datafusion/sql/src/analyzer/count_wildcard_rule.rs index b3b24724552a..bb5d7009a6db 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/sql/src/analyzer/count_wildcard_rule.rs @@ -17,11 +17,11 @@ use crate::analyzer::AnalyzerRule; -use crate::utils::NamePreserver; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; use datafusion_expr::expr::{AggregateFunction, WindowFunction}; +use datafusion_expr::expr_rewriter::NamePreserver; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition}; @@ -95,7 +95,6 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { #[cfg(test)] mod tests { use super::*; - use crate::test::*; use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_expr::expr::Sort; @@ -108,6 +107,10 @@ mod tests { use datafusion_functions_aggregate::expr_fn::max; use std::sync::Arc; + use crate::test::{ + assert_analyzed_plan_eq_display_indent, test_table_scan, + test_table_scan_with_name, + }; use datafusion_functions_aggregate::expr_fn::{count, sum}; fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { diff --git a/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs b/datafusion/sql/src/analyzer/expand_wildcard_rule.rs similarity index 99% rename from datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs rename to datafusion/sql/src/analyzer/expand_wildcard_rule.rs index 9fbe54e1ccb9..a0aba671af48 100644 --- a/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs +++ 
b/datafusion/sql/src/analyzer/expand_wildcard_rule.rs @@ -17,7 +17,7 @@ use std::sync::Arc; -use crate::AnalyzerRule; +use crate::analyzer::AnalyzerRule; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult}; use datafusion_common::{Column, Result}; @@ -175,15 +175,14 @@ fn replace_columns( mod tests { use arrow::datatypes::{DataType, Field, Schema}; + use super::*; + use crate::analyzer::Analyzer; use crate::test::{assert_analyzed_plan_eq_display_indent, test_table_scan}; - use crate::Analyzer; use datafusion_common::{JoinType, TableReference}; use datafusion_expr::{ col, in_subquery, qualified_wildcard, table_scan, wildcard, LogicalPlanBuilder, }; - use super::*; - fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { assert_analyzed_plan_eq_display_indent( Arc::new(ExpandWildcardRule::new()), diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs b/datafusion/sql/src/analyzer/function_rewrite.rs similarity index 97% rename from datafusion/optimizer/src/analyzer/function_rewrite.rs rename to datafusion/sql/src/analyzer/function_rewrite.rs index c6bf14ebce2e..0321b90b11d1 100644 --- a/datafusion/optimizer/src/analyzer/function_rewrite.rs +++ b/datafusion/sql/src/analyzer/function_rewrite.rs @@ -22,8 +22,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{DFSchema, Result}; -use crate::utils::NamePreserver; -use datafusion_expr::expr_rewriter::FunctionRewrite; +use datafusion_expr::expr_rewriter::{FunctionRewrite, NamePreserver}; use datafusion_expr::utils::merge_schema; use datafusion_expr::LogicalPlan; use std::sync::Arc; diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/sql/src/analyzer/inline_table_scan.rs similarity index 100% rename from datafusion/optimizer/src/analyzer/inline_table_scan.rs rename to datafusion/sql/src/analyzer/inline_table_scan.rs index 
342d85a915b4..404bc3cacad9 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/sql/src/analyzer/inline_table_scan.rs @@ -106,8 +106,8 @@ mod tests { use std::{borrow::Cow, sync::Arc, vec}; use crate::analyzer::inline_table_scan::InlineTableScan; - use crate::test::assert_analyzed_plan_eq; + use crate::test::assert_analyzed_plan_eq; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_expr::{col, lit, Expr, LogicalPlan, LogicalPlanBuilder, TableSource}; diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/sql/src/analyzer/mod.rs similarity index 96% rename from datafusion/optimizer/src/analyzer/mod.rs rename to datafusion/sql/src/analyzer/mod.rs index a9fd4900b2f4..6245c9592fb8 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/sql/src/analyzer/mod.rs @@ -20,7 +20,7 @@ use std::fmt::Debug; use std::sync::Arc; -use log::debug; +use log::{debug, trace}; use datafusion_common::config::ConfigOptions; use datafusion_common::instant::Instant; @@ -37,7 +37,6 @@ use crate::analyzer::inline_table_scan::InlineTableScan; use crate::analyzer::resolve_grouping_function::ResolveGroupingFunction; use crate::analyzer::subquery::check_subquery_expr; use crate::analyzer::type_coercion::TypeCoercion; -use crate::utils::log_plan; use self::function_rewrite::ApplyFunctionRewrites; @@ -194,3 +193,9 @@ fn check_plan(plan: &LogicalPlan) -> Result<()> { }) .map(|_| ()) } + +/// Log the plan in debug/tracing mode after some part of the analyzer runs +fn log_plan(description: &str, plan: &LogicalPlan) { + debug!("{description}:\n{}\n", plan.display_indent()); + trace!("{description}::\n{}\n", plan.display_indent_schema()); +} diff --git a/datafusion/sql/src/analyzer/resolve_grouping_function.rs b/datafusion/sql/src/analyzer/resolve_grouping_function.rs new file mode 100644 index 000000000000..16ebb8cd3972 --- /dev/null +++ b/datafusion/sql/src/analyzer/resolve_grouping_function.rs @@ -0,0 +1,247 @@ +// Licensed to
the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Analyzer rule that replaces calls to the `grouping` aggregate function +//! with an expression derived from the internal grouping id. + +use std::cmp::Ordering; +use std::collections::HashMap; +use std::sync::Arc; + +use crate::analyzer::AnalyzerRule; + +use arrow::datatypes::DataType; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{ + internal_datafusion_err, plan_err, Column, DFSchemaRef, Result, ScalarValue, +}; +use datafusion_expr::expr::{AggregateFunction, Alias}; +use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::utils::grouping_set_to_exprlist; +use datafusion_expr::{ + bitwise_and, bitwise_or, bitwise_shift_left, bitwise_shift_right, cast, Aggregate, + Expr, Projection, +}; +use itertools::Itertools; + +/// Replaces grouping aggregation function with value derived from internal grouping id +#[derive(Default, Debug)] +pub struct ResolveGroupingFunction; + +impl ResolveGroupingFunction { + pub fn new() -> Self { + Self {} + } +} + +impl AnalyzerRule for ResolveGroupingFunction { + fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { +
plan.transform_up(analyze_internal).data() + } + + fn name(&self) -> &str { + "resolve_grouping_function" + } +} + +/// Create a map from grouping expr to index in the internal grouping id. +/// +/// For more details on how the grouping id bitmap works, see the documentation for +/// [`Aggregate::INTERNAL_GROUPING_ID`] +fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result> { + Ok(grouping_set_to_exprlist(group_expr)? + .into_iter() + .rev() + .enumerate() + .map(|(idx, v)| (v, idx)) + .collect::>()) +} + +fn replace_grouping_exprs( + input: Arc, + schema: DFSchemaRef, + group_expr: Vec, + aggr_expr: Vec, +) -> Result { + // Create HashMap from Expr to index in the grouping_id bitmap + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); + let group_expr_to_bitmap_index = group_expr_to_bitmap_index(&group_expr)?; + let columns = schema.columns(); + let mut new_agg_expr = Vec::new(); + let mut projection_exprs = Vec::new(); + let grouping_id_len = if is_grouping_set { 1 } else { 0 }; + let group_expr_len = columns.len() - aggr_expr.len() - grouping_id_len; + projection_exprs.extend( + columns + .iter() + .take(group_expr_len) + .map(|column| Expr::Column(column.clone())), + ); + for (expr, column) in aggr_expr + .into_iter() + .zip(columns.into_iter().skip(group_expr_len + grouping_id_len)) + { + match expr { + Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => { + let grouping_expr = grouping_function_on_id( + function, + &group_expr_to_bitmap_index, + is_grouping_set, + )?; + projection_exprs.push(Expr::Alias(Alias::new( + grouping_expr, + column.relation, + column.name, + ))); + } + _ => { + projection_exprs.push(Expr::Column(column)); + new_agg_expr.push(expr); + } + } + } + // Recreate aggregate without grouping functions + let new_aggregate = + LogicalPlan::Aggregate(Aggregate::try_new(input, group_expr, new_agg_expr)?); + // Create projection with grouping functions calculations + let projection =
LogicalPlan::Projection(Projection::try_new( + projection_exprs, + new_aggregate.into(), + )?); + Ok(projection) +} + +fn analyze_internal(plan: LogicalPlan) -> Result> { + // rewrite any subqueries in the plan first + let transformed_plan = + plan.map_subqueries(|plan| plan.transform_up(analyze_internal))?; + + let transformed_plan = transformed_plan.transform_data(|plan| match plan { + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + .. + }) if contains_grouping_function(&aggr_expr) => Ok(Transformed::yes( + replace_grouping_exprs(input, schema, group_expr, aggr_expr)?, + )), + _ => Ok(Transformed::no(plan)), + })?; + + Ok(transformed_plan) +} + +fn is_grouping_function(expr: &Expr) -> bool { + // TODO: Do something better than name here should grouping be a built + // in expression? + matches!(expr, Expr::AggregateFunction(AggregateFunction { ref func, .. }) if func.name() == "grouping") +} + +fn contains_grouping_function(exprs: &[Expr]) -> bool { + exprs.iter().any(is_grouping_function) +} + +/// Validate that the arguments to the grouping function are in the group by clause. 
+fn validate_args( + function: &AggregateFunction, + group_by_expr: &HashMap<&Expr, usize>, +) -> Result<()> { + let expr_not_in_group_by = function + .args + .iter() + .find(|expr| !group_by_expr.contains_key(expr)); + if let Some(expr) = expr_not_in_group_by { + plan_err!( + "Argument {} to grouping function is not in grouping columns {}", + expr, + group_by_expr.keys().map(|e| e.to_string()).join(", ") + ) + } else { + Ok(()) + } +} + +fn grouping_function_on_id( + function: &AggregateFunction, + group_by_expr: &HashMap<&Expr, usize>, + is_grouping_set: bool, +) -> Result { + validate_args(function, group_by_expr)?; + let args = &function.args; + + // Postgres allows grouping function for group by without grouping sets, the result is then + // always 0 + if !is_grouping_set { + return Ok(Expr::Literal(ScalarValue::from(0i32))); + } + + let group_by_expr_count = group_by_expr.len(); + let literal = |value: usize| { + if group_by_expr_count < 8 { + Expr::Literal(ScalarValue::from(value as u8)) + } else if group_by_expr_count < 16 { + Expr::Literal(ScalarValue::from(value as u16)) + } else if group_by_expr_count < 32 { + Expr::Literal(ScalarValue::from(value as u32)) + } else { + Expr::Literal(ScalarValue::from(value as u64)) + } + }; + + let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID)); + // The grouping call is exactly our internal grouping id + if args.len() == group_by_expr_count + && args + .iter() + .rev() + .enumerate() + .all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx)) + { + return Ok(cast(grouping_id_column, DataType::Int32)); + } + + args.iter() + .rev() + .enumerate() + .map(|(arg_idx, expr)| { + group_by_expr.get(expr).map(|group_by_idx| { + let group_by_bit = + bitwise_and(grouping_id_column.clone(), literal(1 << group_by_idx)); + match group_by_idx.cmp(&arg_idx) { + Ordering::Less => { + bitwise_shift_left(group_by_bit, literal(arg_idx - group_by_idx)) + } + Ordering::Greater => { + 
bitwise_shift_right(group_by_bit, literal(group_by_idx - arg_idx)) + } + Ordering::Equal => group_by_bit, + } + }) + }) + .collect::>>() + .and_then(|bit_exprs| { + bit_exprs + .into_iter() + .reduce(bitwise_or) + .map(|expr| cast(expr, DataType::Int32)) + }) + .ok_or_else(|| { + internal_datafusion_err!("Grouping sets should contains at least one element") + }) +} diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/sql/src/analyzer/subquery.rs similarity index 99% rename from datafusion/optimizer/src/analyzer/subquery.rs rename to datafusion/sql/src/analyzer/subquery.rs index e01ae625ed9c..8ead8368bf04 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/sql/src/analyzer/subquery.rs @@ -16,10 +16,10 @@ // under the License. use crate::analyzer::check_plan; -use crate::utils::collect_subquery_cols; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{plan_err, Result}; +use datafusion_expr::expr::collect_subquery_cols; use datafusion_expr::expr_rewriter::strip_outer_reference; use datafusion_expr::utils::split_conjunction; use datafusion_expr::{Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window}; diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/sql/src/analyzer/type_coercion.rs similarity index 99% rename from datafusion/optimizer/src/analyzer/type_coercion.rs rename to datafusion/sql/src/analyzer/type_coercion.rs index 67c120d3f66b..d23f13b7b100 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/sql/src/analyzer/type_coercion.rs @@ -24,7 +24,6 @@ use itertools::izip; use arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; use crate::analyzer::AnalyzerRule; -use crate::utils::NamePreserver; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{ @@ -35,7 +34,7 @@ use datafusion_expr::expr::{ self, Alias, Between, 
BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction, Sort, WindowFunction, }; -use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; +use datafusion_expr::expr_rewriter::{coerce_plan_expr_for_schema, NamePreserver}; use datafusion_expr::expr_schema::cast_subquery; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{ diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index 956f5e17e26f..b705c175cc3b 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -34,6 +34,7 @@ //! [`LogicalPlan`]: datafusion_expr::logical_plan::LogicalPlan //! [`Expr`]: datafusion_expr::expr::Expr +pub mod analyzer; mod cte; mod expr; pub mod parser; @@ -43,6 +44,8 @@ mod relation; mod select; mod set_expr; mod statement; +#[cfg(test)] +mod test; #[cfg(feature = "unparser")] pub mod unparser; pub mod utils; diff --git a/datafusion/sql/src/test.rs b/datafusion/sql/src/test.rs new file mode 100644 index 000000000000..b23626db7dab --- /dev/null +++ b/datafusion/sql/src/test.rs @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::analyzer::{Analyzer, AnalyzerRule}; +use arrow_schema::{DataType, Field, Schema}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{table_scan, LogicalPlan}; +use std::sync::Arc; + +/// some tests share a common table +pub fn test_table_scan() -> datafusion_common::Result { + test_table_scan_with_name("test") +} + +/// some tests share a common table with different names +pub fn test_table_scan_with_name(name: &str) -> datafusion_common::Result { + let schema = Schema::new(test_table_scan_fields()); + table_scan(Some(name), &schema, None)?.build() +} + +pub fn test_table_scan_fields() -> Vec { + vec![ + Field::new("a", DataType::UInt32, false), + Field::new("b", DataType::UInt32, false), + Field::new("c", DataType::UInt32, false), + ] +} + +pub fn assert_analyzed_plan_eq( + rule: Arc, + plan: LogicalPlan, + expected: &str, +) -> datafusion_common::Result<()> { + let options = ConfigOptions::default(); + assert_analyzed_plan_with_config_eq(options, rule, plan, expected)?; + + Ok(()) +} + +pub fn assert_analyzed_plan_with_config_eq( + options: ConfigOptions, + rule: Arc, + plan: LogicalPlan, + expected: &str, +) -> datafusion_common::Result<()> { + let analyzed_plan = + Analyzer::with_rules(vec![rule]).execute_and_check(plan, &options, |_, _| {})?; + let formatted_plan = format!("{analyzed_plan}"); + assert_eq!(formatted_plan, expected); + + Ok(()) +} + +pub fn assert_analyzed_plan_eq_display_indent( + rule: Arc, + plan: LogicalPlan, + expected: &str, +) -> datafusion_common::Result<()> { + let options = ConfigOptions::default(); + let analyzed_plan = + Analyzer::with_rules(vec![rule]).execute_and_check(plan, &options, |_, _| {})?; + let formatted_plan = analyzed_plan.display_indent_schema().to_string(); + assert_eq!(formatted_plan, expected); + + Ok(()) +}