diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 27fdc9e37037..49ce70a5c0bc 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -505,12 +505,11 @@ jobs: - name: Check Cargo.toml formatting run: | - # if you encounter error, try rerun the command below, finally run 'git diff' to - # check which Cargo.toml introduces formatting violation + # If you encounter an error, try running 'cargo tomlfmt -p path/to/Cargo.toml' to fix the formatting automatically. + # If the error persists, manually edit the Cargo.toml file that introduced the formatting violation. # # ignore ./Cargo.toml because putting workspaces in multi-line lists makes it easy to read ci/scripts/rust_toml_fmt.sh - git diff --exit-code config-docs-check: name: check configs.md is up-to-date diff --git a/ci/scripts/rust_toml_fmt.sh b/ci/scripts/rust_toml_fmt.sh index e297ef001594..0a8cc346a37d 100755 --- a/ci/scripts/rust_toml_fmt.sh +++ b/ci/scripts/rust_toml_fmt.sh @@ -17,5 +17,11 @@ # specific language governing permissions and limitations # under the License. +# Run cargo-tomlfmt with the `-d` flag (dry run) to check formatting +# without overwriting the file. If any errors occur, you may want to +# rerun 'cargo tomlfmt -p path/to/Cargo.toml' without '-d' to fix +# the formatting automatically. set -ex -find . -mindepth 2 -name 'Cargo.toml' -exec cargo tomlfmt -p {} \; +for toml in $(find . -mindepth 2 -name 'Cargo.toml'); do + cargo tomlfmt -d -p "$toml" +done diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index f695e81d9876..1c872c28485c 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -112,7 +112,7 @@ dependencies = [ "typed-builder", "uuid", "xz2", - "zstd", + "zstd 0.12.4", ] [[package]] @@ -370,8 +370,8 @@ dependencies = [ "pin-project-lite", "tokio", "xz2", - "zstd", - "zstd-safe", + "zstd 0.12.4", + "zstd-safe 6.0.6", ] [[package]] @@ -1141,7 +1141,7 @@ dependencies = [ "url", "uuid", "xz2", - "zstd", + "zstd 0.13.0", ] [[package]] @@ -1232,7 +1232,7 @@ dependencies = [ "hashbrown 0.14.1", "itertools", "log", - "regex-syntax 0.8.0", + "regex-syntax 0.8.1", ] [[package]] @@ -2319,9 +2319,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "ordered-float" -version = "2.10.0" +version = "2.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7940cf2ca942593318d07fcf2596cdca60a85c9e7fab408a5e21a4f9dcd40d87" +checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c" dependencies = [ "num-traits", ] @@ -2392,7 +2392,7 @@ dependencies = [ "thrift", "tokio", "twox-hash", - "zstd", + "zstd 0.12.4", ] [[package]] @@ -2681,7 +2681,7 @@ dependencies = [ "aho-corasick", "memchr", "regex-automata", - "regex-syntax 0.8.0", + "regex-syntax 0.8.1", ] [[package]] @@ -2692,7 +2692,7 @@ checksum = "465c6fc0621e4abc4187a2bda0937bfd4f722c2730b29562e19689ea796c9a4b" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.0", + "regex-syntax 0.8.1", ] [[package]] @@ -2709,9 +2709,9 @@ checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" [[package]] name = "regex-syntax" -version = "0.8.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3cbb081b9784b07cceb8824c8583f86db4814d172ab043f3c23f7dc600bf83d" +checksum = "56d84fdd47036b038fc80dd333d10b6aab10d5d31f4a366e20014def75328d33" [[package]] name = "reqwest" @@ -3933,7 +3933,16 @@ version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c" dependencies = [ - "zstd-safe", + "zstd-safe 6.0.6", +] + +[[package]] +name = "zstd" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bffb3309596d527cfcba7dfc6ed6052f1d39dfbd7c867aa2e865e4a449c10110" +dependencies = [ + "zstd-safe 7.0.0", ] [[package]] @@ -3946,13 +3955,21 @@ dependencies = [ "zstd-sys", ] +[[package]] +name = "zstd-safe" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43747c7422e2924c11144d5229878b98180ef8b06cca4ab5af37afc8a8d8ea3e" +dependencies = [ + "zstd-sys", +] + [[package]] name = "zstd-sys" -version = "2.0.8+zstd.1.5.5" +version = "2.0.9+zstd.1.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5556e6ee25d32df2586c098bbfa278803692a20d0ab9565e049480d52707ec8c" +checksum = "9e16efa8a874a0481a574084d34cc26fdb3b99627480f785888deb6386506656" dependencies = [ "cc", - "libc", "pkg-config", ] diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 484d203be848..fa580c914ce2 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -90,7 +90,7 @@ tokio-util = { version = "0.7.4", features = ["io"] } url = "2.2" uuid = { version = "1.0", features = ["v4"] } xz2 = { version = "0.1", optional = true } -zstd = { version = "0.12", optional = true, default-features = false } +zstd = { version = "0.13", optional = true, default-features = false } [dev-dependencies] diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index a1f509d28733..9be566f10a72 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -1738,8 +1738,11 @@ mod tests { _t: DisplayFormatType, f: &mut std::fmt::Formatter, ) -> std::fmt::Result { - let expr: Vec = self.expr.iter().map(|e| e.to_string()).collect(); - write!(f, "SortRequiredExec: [{}]", expr.join(",")) + write!( + f, + "SortRequiredExec: [{}]", + PhysicalSortExpr::format_list(&self.expr) + ) } } @@ -3056,16 +3059,16 @@ mod tests { vec![ top_join_plan.as_str(), join_plan.as_str(), - "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[a@0 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", + "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, sort_exprs=b1@1 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b1@1 ASC]", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=c@2 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[c@2 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, 
e]", @@ -3082,21 +3085,21 @@ mod tests { _ => vec![ top_join_plan.as_str(), // Below 4 operators are differences introduced, when join mode is changed - "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[a@0 ASC]", "CoalescePartitionsExec", join_plan.as_str(), - "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[a@0 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", + "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, sort_exprs=b1@1 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b1@1 ASC]", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=c@2 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[c@2 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", @@ -3170,16 +3173,16 @@ mod tests { JoinType::Inner | JoinType::Right => vec![ top_join_plan.as_str(), join_plan.as_str(), - "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[a@0 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", + "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, sort_exprs=b1@1 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b1@1 ASC]", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=c@2 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[c@2 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", @@ -3187,21 +3190,21 @@ mod tests { // Should include 8 RepartitionExecs (4 of them preserves order) and 4 SortExecs JoinType::Left | JoinType::Full => vec![ top_join_plan.as_str(), - "SortPreservingRepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10", + "SortPreservingRepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10, sort_exprs=b1@6 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b1@6 ASC]", "CoalescePartitionsExec", join_plan.as_str(), - "SortPreservingRepartitionExec: 
partitioning=Hash([a@0], 10), input_partitions=10", + "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[a@0 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", + "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, sort_exprs=b1@1 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b1@1 ASC]", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=c@2 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[c@2 ASC]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", @@ -3292,7 +3295,7 @@ mod tests { let expected_first_sort_enforcement = &[ "SortMergeJoin: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)]", - "SortPreservingRepartitionExec: partitioning=Hash([b3@1, a3@0], 10), input_partitions=10", + "SortPreservingRepartitionExec: partitioning=Hash([b3@1, a3@0], 10), input_partitions=10, sort_exprs=b3@1 ASC,a3@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b3@1 ASC,a3@0 ASC]", "CoalescePartitionsExec", @@ -3303,7 +3306,7 @@ mod tests { "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortPreservingRepartitionExec: partitioning=Hash([b2@1, a2@0], 10), input_partitions=10", + "SortPreservingRepartitionExec: partitioning=Hash([b2@1, a2@0], 10), input_partitions=10, sort_exprs=b2@1 ASC,a2@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "SortExec: expr=[b2@1 ASC,a2@0 ASC]", "CoalescePartitionsExec", @@ -4382,7 +4385,7 @@ mod tests { let expected = &[ "SortPreservingMergeExec: [c@2 ASC]", "FilterExec: c@2 = 0", - "SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + "SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, sort_exprs=c@2 ASC", "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", ]; diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 95ec1973d017..92db3bbd053e 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -62,6 +62,7 @@ use datafusion_physical_expr::utils::{ }; use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; +use datafusion_physical_plan::repartition::RepartitionExec; use itertools::izip; /// This rule inspects [`SortExec`]'s in the given physical plan and removes the @@ -566,7 +567,6 @@ fn analyze_window_sort_removal( ); let mut window_child = remove_corresponding_sort_from_sub_plan(sort_tree, requires_single_partition)?; - let (window_expr, new_window) = if let Some(exec) = window_exec.as_any().downcast_ref::() { ( @@ -608,12 +608,10 @@ fn 
analyze_window_sort_removal( add_sort_above(&mut window_child, sort_expr, None)?; let uses_bounded_memory = window_expr.iter().all(|e| e.uses_bounded_memory()); - let input_schema = window_child.schema(); let new_window = if uses_bounded_memory { Arc::new(BoundedWindowAggExec::try_new( window_expr.to_vec(), window_child, - input_schema, partitionby_exprs.to_vec(), PartitionSearchMode::Sorted, )?) as _ @@ -621,7 +619,6 @@ fn analyze_window_sort_removal( Arc::new(WindowAggExec::try_new( window_expr.to_vec(), window_child, - input_schema, partitionby_exprs.to_vec(), )?) as _ }; @@ -704,8 +701,18 @@ fn remove_corresponding_sort_from_sub_plan( children[item.idx] = remove_corresponding_sort_from_sub_plan(item, requires_single_partition)?; } + // Replace with variants that do not preserve order. if is_sort_preserving_merge(plan) { children[0].clone() + } else if let Some(repartition) = plan.as_any().downcast_ref::<RepartitionExec>() + { + Arc::new( + RepartitionExec::try_new( + children[0].clone(), + repartition.partitioning().clone(), + )? + .with_preserve_order(false), + ) } else { plan.clone().with_new_children(children)? } @@ -758,7 +765,7 @@ mod tests { coalesce_partitions_exec, filter_exec, global_limit_exec, hash_join_exec, limit_exec, local_limit_exec, memory_exec, parquet_exec, parquet_exec_sorted, repartition_exec, sort_exec, sort_expr, sort_expr_options, sort_merge_join_exec, - sort_preserving_merge_exec, union_exec, + sort_preserving_merge_exec, spr_repartition_exec, union_exec, }; use crate::physical_optimizer::utils::get_plan_string; use crate::physical_plan::repartition::RepartitionExec; @@ -1635,14 +1642,16 @@ mod tests { // During the removal of `SortExec`s, it should be able to remove the // corresponding SortExecs together. Also, the inputs of these `SortExec`s // are not necessarily the same to be able to remove them.
- let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + let expected_input = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", " SortPreservingMergeExec: [nullable_col@0 DESC NULLS LAST]", " UnionExec", " SortExec: expr=[nullable_col@0 DESC NULLS LAST]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC]", " SortExec: expr=[nullable_col@0 DESC NULLS LAST]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]"]; - let expected_optimized = ["WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL) }]", + let expected_optimized = [ + "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL) }]", " SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC]", @@ -2147,15 +2156,19 @@ mod tests { let coalesce_partitions = coalesce_partitions_exec(repartition_hash); let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); - let expected_input = ["SortExec: expr=[a@0 ASC]", + let expected_input = [ + "SortExec: expr=[a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC]", - " SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false" + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false"]; + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2177,11 +2190,14 @@ mod tests { " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CsvExec: file_groups={1 group: [[x]]}, 
projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC]", - " SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false" + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false"]; + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, false); Ok(()) } @@ -2234,4 +2250,36 @@ mod tests { assert_optimized!(expected_input, expected_optimized, physical_plan, false); Ok(()) } + + #[tokio::test] + async fn test_window_multi_layer_requirement() -> Result<()> { + let schema = create_test_schema3()?; + let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; + let source = csv_exec_sorted(&schema, vec![], false); + let sort = sort_exec(sort_exprs.clone(), source); + let repartition = repartition_exec(sort); + let repartition = spr_repartition_exec(repartition); + let spm = sort_preserving_merge_exec(sort_exprs.clone(), repartition); + + let physical_plan = bounded_window_exec("a", sort_exprs, spm); + + let expected_input = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + " SortPreservingMergeExec: [a@0 ASC,b@1 ASC]", + " SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, sort_exprs=a@0 ASC,b@1 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortExec: expr=[a@0 ASC,b@1 ASC]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + let expected_optimized = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + " SortExec: expr=[a@0 ASC,b@1 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, false); + Ok(()) + } } diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index ede95fc67721..f4b3608d00c7 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -368,7 +368,7 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: 
file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; assert_optimized!(expected_input, expected_optimized, physical_plan); @@ -412,10 +412,10 @@ mod tests { let expected_optimized = [ "SortPreservingMergeExec: [a@0 ASC]", " FilterExec: c@2 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " SortPreservingMergeExec: [a@0 ASC]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC], has_header=true", ]; @@ -442,11 +442,14 @@ mod tests { " FilterExec: c@2 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " FilterExec: c@2 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + + ]; assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) } @@ -475,7 +478,7 @@ mod tests { let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@2 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; assert_optimized!(expected_input, expected_optimized, physical_plan); @@ -509,7 +512,7 @@ mod tests { let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@2 > 3", - " 
SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; @@ -561,21 +564,25 @@ mod tests { let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - let expected_input = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + let expected_input = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", " SortExec: expr=[a@0 ASC NULLS LAST]", " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@2 > 3", " RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true" + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@2 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) } @@ -627,7 +634,7 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; assert_optimized!(expected_input, expected_optimized, physical_plan); @@ -671,7 +678,7 @@ mod tests { let expected_optimized = [ "SortPreservingMergeExec: [c@2 ASC]", " FilterExec: c@2 > 3", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " SortPreservingRepartitionExec: 
partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=c@2 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " SortExec: expr=[c@2 ASC]", " CoalescePartitionsExec", @@ -756,7 +763,7 @@ mod tests { " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8", + " SortPreservingRepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true"]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 0915fdbf1cd7..ed7345651457 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -238,7 +238,6 @@ pub fn bounded_window_exec( ) .unwrap()], input.clone(), - input.schema(), vec![], crate::physical_plan::windows::PartitionSearchMode::Sorted, ) @@ -324,6 +323,14 @@ pub fn repartition_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> Arc::new(RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(10)).unwrap()) } +pub fn spr_repartition_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> { + Arc::new( + RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(10)) + .unwrap() + .with_preserve_order(true), + ) +} + pub fn aggregate_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> { let schema = input.schema(); Arc::new( diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 35119f374fa3..4055f9f4ebd2 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -751,7 +751,6 @@ impl DefaultPhysicalPlanner { Arc::new(BoundedWindowAggExec::try_new( window_expr, input_exec, - physical_input_schema, physical_partition_keys, PartitionSearchMode::Sorted, )?) @@ -759,7 +758,6 @@ impl DefaultPhysicalPlanner { Arc::new(WindowAggExec::try_new( window_expr, input_exec, - physical_input_schema, physical_partition_keys, )?)
}) @@ -1251,10 +1249,10 @@ impl DefaultPhysicalPlanner { "Unsupported logical plan: Prepare" ) } - LogicalPlan::Dml(_) => { + LogicalPlan::Dml(dml) => { // DataFusion is a read-only query engine, but also a library, so consumers may implement this not_impl_err!( - "Unsupported logical plan: Dml" + "Unsupported logical plan: Dml({0})", dml.op ) } LogicalPlan::Statement(statement) => { diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 1f0a4b09b15f..83c8e1f57896 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -461,7 +461,6 @@ async fn run_window_test( ) .unwrap()], exec1, - schema.clone(), vec![], ) .unwrap(), @@ -484,7 +483,6 @@ async fn run_window_test( ) .unwrap()], exec2, - schema.clone(), vec![], search_mode, ) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 3e4e3068977c..3949d25b3025 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -165,9 +165,15 @@ pub enum Expr { InSubquery(InSubquery), /// Scalar subquery ScalarSubquery(Subquery), - /// Represents a reference to all fields in a schema. + /// Represents a reference to all available fields. + /// + /// This expr has to be resolved to a list of columns before the logical + /// plan is translated into a physical plan. Wildcard, - /// Represents a reference to all fields in a specific schema. + /// Represents a reference to all available fields in a specific schema. + /// + /// This expr has to be resolved to a list of columns before the logical + /// plan is translated into a physical plan. QualifiedWildcard { qualifier: String }, /// List of grouping set expressions. Only valid in the context of an aggregate /// GROUP BY expression list diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 8397dd3aac34..4f810fcdc04f 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -659,13 +659,11 @@ impl LogicalPlanBuilder { /// /// let right_plan = LogicalPlanBuilder::scan("right", right_table, None)?.build()?; /// - /// // Form the expression `(left.a != right.a AND left.b != right.b)` + /// // Form the expression `(left.a != right.a)` AND `(left.b != right.b)` /// let exprs = vec![ /// col("left.a").eq(col("right.a")), /// col("left.b").not_eq(col("right.b")) - /// ] - /// .into_iter() - /// .reduce(Expr::and); + /// ]; /// /// // Perform the equivalent of `left INNER JOIN right ON (a != a2 AND b != b2)` /// // finding all pairs of rows from `left` and `right` where @@ -680,13 +678,15 @@ impl LogicalPlanBuilder { self, right: LogicalPlan, join_type: JoinType, - on_exprs: Option<Expr>, + on_exprs: impl IntoIterator<Item = Expr>, ) -> Result<Self> { + let filter = on_exprs.into_iter().reduce(Expr::and); + self.join_detailed( right, join_type, (Vec::<Column>::new(), Vec::<Column>::new()), - on_exprs, + filter, false, ) } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index dfc83f9eec76..b865b6855724 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -142,7 +142,7 @@ pub enum LogicalPlan { Prepare(Prepare), /// Data Manipulation Language (DML): Insert / Update / Delete Dml(DmlStatement), - /// Data Definition Language (DDL): CREATE / DROP TABLES / VIEWS / SCHEMAs + /// Data Definition Language (DDL): CREATE / DROP TABLES / VIEWS / SCHEMAS Ddl(DdlStatement), /// `COPY TO` for writing plan results to files
Copy(CopyTo), diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 432d7f053aef..96b46663d8e4 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -21,7 +21,7 @@ use crate::utils::{conjunction, replace_qualified_name, split_conjunction}; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::TreeNode; -use datafusion_common::{plan_err, Column, DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; @@ -282,12 +282,7 @@ fn build_join( false => JoinType::LeftSemi, }; let new_plan = LogicalPlanBuilder::from(left.clone()) - .join( - sub_query_alias, - join_type, - (Vec::<Column>::new(), Vec::<Column>::new()), - Some(join_filter), - )? + .join_on(sub_query_alias, join_type, Some(join_filter))? .build()?; debug!( "predicate subquery optimized:\n{}", diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index 00abcdcc68aa..0dbebcc8a051 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -77,7 +77,7 @@ impl OptimizerRule for EliminateJoin { mod tests { use crate::eliminate_join::EliminateJoin; use crate::test::*; - use datafusion_common::{Column, Result, ScalarValue}; + use datafusion_common::{Result, ScalarValue}; use datafusion_expr::JoinType::Inner; use datafusion_expr::{logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan}; use std::sync::Arc; @@ -89,10 +89,9 @@ mod tests { #[test] fn join_on_false() -> Result<()> { let plan = LogicalPlanBuilder::empty(false) - .join( + .join_on( LogicalPlanBuilder::empty(false).build()?, Inner, - (Vec::<Column>::new(), Vec::<Column>::new()), Some(Expr::Literal(ScalarValue::Boolean(Some(false)))), )? .build()?; @@ -104,10 +103,9 @@ mod tests { #[test] fn join_on_true() -> Result<()> { let plan = LogicalPlanBuilder::empty(false) - .join( + .join_on( LogicalPlanBuilder::empty(false).build()?, Inner, - (Vec::<Column>::new(), Vec::<Column>::new()), Some(Expr::Literal(ScalarValue::Boolean(Some(true)))), )? .build()?; diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index e22c73e5794d..89bcc90bc075 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -16,12 +16,11 @@ // under the License. //! Optimizer rule to replace nested unions to single union. +use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; -use datafusion_expr::logical_plan::{LogicalPlan, Union}; - -use crate::optimizer::ApplyOrder; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; +use datafusion_expr::{Distinct, LogicalPlan, Union}; use std::sync::Arc; #[derive(Default)] @@ -41,22 +40,11 @@ impl OptimizerRule for EliminateNestedUnion { plan: &LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result<Option<LogicalPlan>> { - // TODO: Add optimization for nested distinct unions.
match plan { LogicalPlan::Union(Union { inputs, schema }) => { let inputs = inputs .iter() - .flat_map(|plan| match plan.as_ref() { - LogicalPlan::Union(Union { inputs, schema }) => inputs - .iter() - .map(|plan| { - Arc::new( - coerce_plan_expr_for_schema(plan, schema).unwrap(), - ) - }) - .collect::<Vec<_>>(), - _ => vec![plan.clone()], - }) + .flat_map(extract_plans_from_union) .collect::<Vec<_>>(); Ok(Some(LogicalPlan::Union(Union { @@ -64,6 +52,23 @@ impl OptimizerRule for EliminateNestedUnion { schema: schema.clone(), }))) } + LogicalPlan::Distinct(Distinct { input: plan }) => match plan.as_ref() { + LogicalPlan::Union(Union { inputs, schema }) => { + let inputs = inputs + .iter() + .map(extract_plan_from_distinct) + .flat_map(extract_plans_from_union) + .collect::<Vec<_>>(); + + Ok(Some(LogicalPlan::Distinct(Distinct { + input: Arc::new(LogicalPlan::Union(Union { + inputs, + schema: schema.clone(), + })), + }))) + } + _ => Ok(None), + }, _ => Ok(None), } } @@ -77,6 +82,23 @@ impl OptimizerRule for EliminateNestedUnion { } } +fn extract_plans_from_union(plan: &Arc<LogicalPlan>) -> Vec<Arc<LogicalPlan>> { + match plan.as_ref() { + LogicalPlan::Union(Union { inputs, schema }) => inputs + .iter() + .map(|plan| Arc::new(coerce_plan_expr_for_schema(plan, schema).unwrap())) + .collect::<Vec<_>>(), + _ => vec![plan.clone()], + } +} + +fn extract_plan_from_distinct(plan: &Arc<LogicalPlan>) -> &Arc<LogicalPlan> { + match plan.as_ref() { + LogicalPlan::Distinct(Distinct { input: plan }) => plan, + _ => plan, + } +} + #[cfg(test)] mod tests { use super::*; @@ -112,6 +134,22 @@ mod tests { assert_optimized_plan_equal(&plan, expected) } + #[test] + fn eliminate_distinct_nothing() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union_distinct(plan_builder.clone().build()?)? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + #[test] fn eliminate_nested_union() -> Result<()> { let plan_builder = table_scan(Some("table"), &schema(), None)?; @@ -132,6 +170,69 @@ mod tests { assert_optimized_plan_equal(&plan, expected) } + #[test] + fn eliminate_nested_union_with_distinct_union() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union_distinct(plan_builder.clone().build()?)? + .union(plan_builder.clone().build()?)? + .union(plan_builder.clone().build()?)? + .build()?; + + let expected = "Union\ + \n Distinct:\ + \n Union\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_distinct_union() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union(plan_builder.clone().build()?)? + .union_distinct(plan_builder.clone().build()?)? + .union(plan_builder.clone().build()?)? + .union_distinct(plan_builder.clone().build()?)? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_distinct_union_with_distinct_table() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union_distinct(plan_builder.clone().distinct()?.build()?)? + .union(plan_builder.clone().distinct()?.build()?)?
+ .union_distinct(plan_builder.clone().build()?)? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + // We don't need to use project_with_column_index in logical optimizer, // after LogicalPlanBuilder::union, we already have all equal expression aliases #[test] @@ -163,6 +264,36 @@ mod tests { assert_optimized_plan_equal(&plan, expected) } + #[test] + fn eliminate_nested_distinct_union_with_projection() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union_distinct( + plan_builder + .clone() + .project(vec![col("id").alias("table_id"), col("key"), col("value")])? + .build()?, + )? + .union_distinct( + plan_builder + .clone() + .project(vec![col("id").alias("_id"), col("key"), col("value")])? + .build()?, + )? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table\ + \n Projection: table.id AS id, table.key, table.value\ + \n TableScan: table\ + \n Projection: table.id AS id, table.key, table.value\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + #[test] fn eliminate_nested_union_with_type_cast_projection() -> Result<()> { let table_1 = table_scan( @@ -208,4 +339,51 @@ mod tests { \n TableScan: table_1"; assert_optimized_plan_equal(&plan, expected) } + + #[test] + fn eliminate_nested_distinct_union_with_type_cast_projection() -> Result<()> { + let table_1 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float64, false), + ]), + None, + )?; + + let table_2 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float32, false), + ]), + None, + )?; + + let table_3 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int16, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float32, false), + ]), + None, + )?; + + let plan = table_1 + .union_distinct(table_2.build()?)? + .union_distinct(table_3.build()?)? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table_1\ + \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ + \n TableScan: table_1\ + \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ + \n TableScan: table_1"; + assert_optimized_plan_equal(&plan, expected) + } } diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index e328eeeb00a1..575969fbf73c 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -161,7 +161,6 @@ mod tests { use super::*; use crate::test::*; use arrow::datatypes::DataType; - use datafusion_common::Column; use datafusion_expr::{ col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType, }; @@ -182,12 +181,7 @@ mod tests { let t2 = test_table_scan_with_name("t2")?; let plan = LogicalPlanBuilder::from(t1) - .join( - t2, - JoinType::Left, - (Vec::<Column>::new(), Vec::<Column>::new()), - Some(col("t1.a").eq(col("t2.a"))), - )? + .join_on(t2, JoinType::Left, Some(col("t1.a").eq(col("t2.a"))))?
.build()?; let expected = "Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; @@ -202,10 +196,9 @@ mod tests { let t2 = test_table_scan_with_name("t2")?; let plan = LogicalPlanBuilder::from(t1) - .join( + .join_on( t2, JoinType::Left, - (Vec::<Column>::new(), Vec::<Column>::new()), Some((col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32))), )? .build()?; @@ -222,10 +215,9 @@ mod tests { let t2 = test_table_scan_with_name("t2")?; let plan = LogicalPlanBuilder::from(t1) - .join( + .join_on( t2, JoinType::Left, - (Vec::<Column>::new(), Vec::<Column>::new()), Some( (col("t1.a") + lit(10i64)) .gt_eq(col("t2.a") * lit(2u32)) .and(col("t1.b").lt(lit(100i32))), ), )? .build()?; @@ -273,10 +265,9 @@ mod tests { let t2 = test_table_scan_with_name("t2")?; let plan = LogicalPlanBuilder::from(t1) - .join( + .join_on( t2, JoinType::Left, - (Vec::<Column>::new(), Vec::<Column>::new()), Some( col("t1.c") .eq(col("t2.c")) .and((col("t1.a") + col("t1.b")).gt(col("t2.b") + col("t2.c"))), ), )? .build()?; @@ -301,10 +292,9 @@ mod tests { let t3 = test_table_scan_with_name("t3")?; let input = LogicalPlanBuilder::from(t2) - .join( + .join_on( t3, JoinType::Left, - (Vec::<Column>::new(), Vec::<Column>::new()), Some( col("t2.a") .eq(col("t3.a")) .and(col("t2.b").eq(col("t3.b"))), ), )? .build()?; let plan = LogicalPlanBuilder::from(t1) - .join( + .join_on( input, JoinType::Left, - (Vec::<Column>::new(), Vec::<Column>::new()), Some( col("t1.a") .eq(col("t2.a")) .and(col("t1.c").eq(col("t3.c"))), ), )? .build()?; @@ -340,10 +329,9 @@ mod tests { let t3 = test_table_scan_with_name("t3")?; let input = LogicalPlanBuilder::from(t2) - .join( + .join_on( t3, JoinType::Left, - (Vec::<Column>::new(), Vec::<Column>::new()), Some( col("t2.a") .eq(col("t3.a")) .and(col("t2.b").eq(col("t3.b"))), ), )? .build()?; let plan = LogicalPlanBuilder::from(t1) - .join( + .join_on( input, JoinType::Left, - (Vec::<Column>::new(), Vec::<Column>::new()), Some(col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")))), )? .build()?; @@ -383,12 +370,7 @@ mod tests { ) .alias("t1.a + 1 = t2.a + 2"); let plan = LogicalPlanBuilder::from(t1) - .join( - t2, - JoinType::Left, - (Vec::<Column>::new(), Vec::<Column>::new()), - Some(filter), - )? + .join_on(t2, JoinType::Left, Some(filter))? .build()?; let expected = "Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 96d2f45d808e..7ac0c25119c3 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -315,24 +315,14 @@ fn build_join( _ => { // if not correlated, group down to 1 row and left join on that (preserving row count) LogicalPlanBuilder::from(filter_input.clone()) - .join( - sub_query_alias, - JoinType::Left, - (Vec::<Column>::new(), Vec::<Column>::new()), - None, - )? + .join_on(sub_query_alias, JoinType::Left, None)? .build()? } } } else { // left join if correlated, grouping by the join keys so we don't change row count LogicalPlanBuilder::from(filter_input.clone()) - .join( - sub_query_alias, - JoinType::Left, - (Vec::<Column>::new(), Vec::<Column>::new()), - join_filter_opt, - )? + .join_on(sub_query_alias, JoinType::Left, join_filter_opt)? .build()? }; let mut computation_project_expr = HashMap::new(); diff --git a/datafusion/physical-expr/src/sort_expr.rs b/datafusion/physical-expr/src/sort_expr.rs index 83d32dfeec17..74179ba5947c 100644 --- a/datafusion/physical-expr/src/sort_expr.rs +++ b/datafusion/physical-expr/src/sort_expr.rs @@ -17,6 +17,7 @@ //!
Sort expressions +use std::fmt::Display; use std::hash::{Hash, Hasher}; use std::sync::Arc; @@ -89,6 +90,26 @@ impl PhysicalSortExpr { .options .map_or(true, |opts| self.options == opts) } + + /// Returns a [`Display`]able list of `PhysicalSortExpr`. + pub fn format_list(input: &[PhysicalSortExpr]) -> impl Display + '_ { + struct DisplayableList<'a>(&'a [PhysicalSortExpr]); + impl<'a> Display for DisplayableList<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let mut first = true; + for sort_expr in self.0 { + if first { + first = false; + } else { + write!(f, ",")?; + } + write!(f, "{}", sort_expr)?; + } + Ok(()) + } + } + DisplayableList(input) + } } /// Represents sort requirement associated with a plan diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 4c83ff1528fa..d919ded8d0e2 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -77,9 +77,10 @@ impl DisplayAs for MemoryExec { .sort_information .first() .map(|output_ordering| { - let order_strings: Vec<_> = - output_ordering.iter().map(|e| e.to_string()).collect(); - format!(", output_ordering={}", order_strings.join(",")) + format!( + ", output_ordering={}", + PhysicalSortExpr::format_list(output_ordering) + ) }) .unwrap_or_default(); diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 14b54dc0614d..bcb9c3afeef1 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -308,7 +308,8 @@ pub struct RepartitionExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, - /// Boolean flag to decide whether to preserve ordering + /// Boolean flag to decide whether to preserve ordering. If true, the operator + /// is displayed as `SortPreservingRepartitionExec`; if false, as `RepartitionExec`. preserve_order: bool, } @@ -370,7 +371,7 @@ impl RepartitionExec { self.preserve_order } - /// Get name of the Executor + /// Get the name used to display this Exec pub fn name(&self) -> &str { if self.preserve_order { "SortPreservingRepartitionExec" @@ -394,7 +395,16 @@ impl DisplayAs for RepartitionExec { self.name(), self.partitioning, self.input.output_partitioning().partition_count() - ) + )?; + + if let Some(sort_exprs) = self.sort_exprs() { + write!( + f, + ", sort_exprs={}", + PhysicalSortExpr::format_list(sort_exprs) + )?; + } + Ok(()) } } } @@ -576,8 +586,8 @@ impl ExecutionPlan for RepartitionExec { .collect::<Vec<_>>(); // Note that receiver size (`rx.len()`) and `num_input_partitions` are same.
- // Get existing ordering: - let sort_exprs = self.input.output_ordering().unwrap_or(&[]); + // Get existing ordering to use for merging + let sort_exprs = self.sort_exprs().unwrap_or(&[]); // Merge streams (while preserving ordering) coming from // input partitions to this partition: @@ -646,6 +656,15 @@ impl RepartitionExec { self } + /// Return the sort expressions that are used to merge + fn sort_exprs(&self) -> Option<&[PhysicalSortExpr]> { + if self.preserve_order { + self.input.output_ordering() + } else { + None + } + } + /// Pulls data from the specified input plan, feeding it to the /// output partitions based on the desired partitioning /// diff --git a/datafusion/physical-plan/src/sorts/merge.rs b/datafusion/physical-plan/src/sorts/merge.rs index 67685509abe5..e60baf2cd806 100644 --- a/datafusion/physical-plan/src/sorts/merge.rs +++ b/datafusion/physical-plan/src/sorts/merge.rs @@ -21,84 +21,21 @@ use crate::metrics::BaselineMetrics; use crate::sorts::builder::BatchBuilder; use crate::sorts::cursor::Cursor; -use crate::sorts::stream::{FieldCursorStream, PartitionedStream, RowCursorStream}; -use crate::{PhysicalSortExpr, RecordBatchStream, SendableRecordBatchStream}; -use arrow::datatypes::{DataType, SchemaRef}; +use crate::sorts::stream::PartitionedStream; +use crate::RecordBatchStream; +use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use arrow_array::*; use datafusion_common::Result; use datafusion_execution::memory_pool::MemoryReservation; use futures::Stream; use std::pin::Pin; use std::task::{ready, Context, Poll}; -macro_rules! primitive_merge_helper { - ($t:ty, $($v:ident),+) => { - merge_helper!(PrimitiveArray<$t>, $($v),+) - }; -} - -macro_rules! merge_helper { - ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $fetch:ident, $reservation:ident) => {{ - let streams = FieldCursorStream::<$t>::new($sort, $streams); - return Ok(Box::pin(SortPreservingMergeStream::new( - Box::new(streams), - $schema, - $tracking_metrics, - $batch_size, - $fetch, - $reservation, - ))); - }}; -} - -/// Perform a streaming merge of [`SendableRecordBatchStream`] based on provided sort expressions -/// while preserving order. -pub fn streaming_merge( - streams: Vec<SendableRecordBatchStream>, - schema: SchemaRef, - expressions: &[PhysicalSortExpr], - metrics: BaselineMetrics, - batch_size: usize, - fetch: Option<usize>, - reservation: MemoryReservation, -) -> Result<SendableRecordBatchStream> { - // Special case single column comparisons with optimized cursor implementations - if expressions.len() == 1 { - let sort = expressions[0].clone(); - let data_type = sort.expr.data_type(schema.as_ref())?; - downcast_primitive!
{ - data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation), - DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - _ => {} - } - } - - let streams = RowCursorStream::try_new( - schema.as_ref(), - expressions, - streams, - reservation.new_empty(), - )?; - - Ok(Box::pin(SortPreservingMergeStream::new( - Box::new(streams), - schema, - metrics, - batch_size, - fetch, - reservation, - ))) -} - /// A fallible [`PartitionedStream`] of [`Cursor`] and [`RecordBatch`] type CursorStream<C> = Box<dyn PartitionedStream<Output = Result<(C, RecordBatch)>>>; #[derive(Debug)] -struct SortPreservingMergeStream<C: Cursor> { +pub(crate) struct SortPreservingMergeStream<C: Cursor> { in_progress: BatchBuilder, /// The sorted input streams to merge together @@ -162,7 +99,7 @@ struct SortPreservingMergeStream<C: Cursor> { } impl<C: Cursor> SortPreservingMergeStream<C> { - fn new( + pub(crate) fn new( streams: CursorStream<C>, schema: SchemaRef, metrics: BaselineMetrics, diff --git a/datafusion/physical-plan/src/sorts/mod.rs b/datafusion/physical-plan/src/sorts/mod.rs index dff39db423f0..8a1184d3c2b5 100644 --- a/datafusion/physical-plan/src/sorts/mod.rs +++ b/datafusion/physical-plan/src/sorts/mod.rs @@ -20,10 +20,11 @@ mod builder; mod cursor; mod index; -pub mod merge; +mod merge; pub mod sort; pub mod sort_preserving_merge; mod stream; +pub mod streaming_merge; pub use index::RowIndex; -pub(crate) use merge::streaming_merge; +pub(crate) use streaming_merge::streaming_merge; diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 703f80d90d2b..ffc4ef9dc32a 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -24,7 +24,7 @@ use crate::expressions::PhysicalSortExpr; use crate::metrics::{ BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, }; -use crate::sorts::merge::streaming_merge; +use crate::sorts::streaming_merge::streaming_merge; use crate::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; use crate::topk::TopK; use crate::{ @@ -763,17 +763,16 @@ impl DisplayAs for SortExec { ) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - let expr: Vec<String> = self.expr.iter().map(|e| e.to_string()).collect(); + let expr = PhysicalSortExpr::format_list(&self.expr); match self.fetch { Some(fetch) => { write!( f, // TODO should this say topk?
diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs
index 703f80d90d2b..ffc4ef9dc32a 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -24,7 +24,7 @@
 use crate::expressions::PhysicalSortExpr;
 use crate::metrics::{
     BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet,
 };
-use crate::sorts::merge::streaming_merge;
+use crate::sorts::streaming_merge::streaming_merge;
 use crate::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter};
 use crate::topk::TopK;
 use crate::{
@@ -763,17 +763,16 @@ impl DisplayAs for SortExec {
     ) -> std::fmt::Result {
         match t {
             DisplayFormatType::Default | DisplayFormatType::Verbose => {
-                let expr: Vec<String> = self.expr.iter().map(|e| e.to_string()).collect();
+                let expr = PhysicalSortExpr::format_list(&self.expr);
                 match self.fetch {
                     Some(fetch) => {
                         write!(
                             f,
                             // TODO should this say topk?
-                            "SortExec: fetch={fetch}, expr=[{}]",
-                            expr.join(",")
+                            "SortExec: fetch={fetch}, expr=[{expr}]",
                         )
                     }
-                    None => write!(f, "SortExec: expr=[{}]", expr.join(",")),
+                    None => write!(f, "SortExec: expr=[{expr}]"),
                 }
             }
         }
diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
index 5b485e0b68e4..597b59f776d5 100644
--- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
+++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
@@ -118,8 +118,11 @@ impl DisplayAs for SortPreservingMergeExec {
     ) -> std::fmt::Result {
         match t {
             DisplayFormatType::Default | DisplayFormatType::Verbose => {
-                let expr: Vec<String> = self.expr.iter().map(|e| e.to_string()).collect();
-                write!(f, "SortPreservingMergeExec: [{}]", expr.join(","))?;
+                write!(
+                    f,
+                    "SortPreservingMergeExec: [{}]",
+                    PhysicalSortExpr::format_list(&self.expr)
+                )?;
                 if let Some(fetch) = self.fetch {
                     write!(f, ", fetch={fetch}")?;
                 };
diff --git a/datafusion/physical-plan/src/sorts/streaming_merge.rs b/datafusion/physical-plan/src/sorts/streaming_merge.rs
new file mode 100644
index 000000000000..96d180027eee
--- /dev/null
+++ b/datafusion/physical-plan/src/sorts/streaming_merge.rs
@@ -0,0 +1,92 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Merge that deals with an arbitrary number of streaming inputs.
+//! This is an order-preserving merge.
+
+use crate::metrics::BaselineMetrics;
+use crate::sorts::{
+    merge::SortPreservingMergeStream,
+    stream::{FieldCursorStream, RowCursorStream},
+};
+use crate::{PhysicalSortExpr, SendableRecordBatchStream};
+use arrow::datatypes::{DataType, SchemaRef};
+use arrow_array::*;
+use datafusion_common::Result;
+use datafusion_execution::memory_pool::MemoryReservation;
+
+macro_rules! primitive_merge_helper {
+    ($t:ty, $($v:ident),+) => {
+        merge_helper!(PrimitiveArray<$t>, $($v),+)
+    };
+}
+
+macro_rules! merge_helper {
+    ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $fetch:ident, $reservation:ident) => {{
+        let streams = FieldCursorStream::<$t>::new($sort, $streams);
+        return Ok(Box::pin(SortPreservingMergeStream::new(
+            Box::new(streams),
+            $schema,
+            $tracking_metrics,
+            $batch_size,
+            $fetch,
+            $reservation,
+        )));
+    }};
+}
+
+/// Perform a streaming merge of [`SendableRecordBatchStream`] based on provided sort expressions
+/// while preserving order.
+pub fn streaming_merge(
+    streams: Vec<SendableRecordBatchStream>,
+    schema: SchemaRef,
+    expressions: &[PhysicalSortExpr],
+    metrics: BaselineMetrics,
+    batch_size: usize,
+    fetch: Option<usize>,
+    reservation: MemoryReservation,
+) -> Result<SendableRecordBatchStream> {
+    // Special case single column comparisons with optimized cursor implementations
+    if expressions.len() == 1 {
+        let sort = expressions[0].clone();
+        let data_type = sort.expr.data_type(schema.as_ref())?;
+        downcast_primitive! {
+            data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation),
+            DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch, reservation)
+            DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch, reservation)
+            DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation)
+            DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation)
+            _ => {}
+        }
+    }
+
+    let streams = RowCursorStream::try_new(
+        schema.as_ref(),
+        expressions,
+        streams,
+        reservation.new_empty(),
+    )?;
+
+    Ok(Box::pin(SortPreservingMergeStream::new(
+        Box::new(streams),
+        schema,
+        metrics,
+        batch_size,
+        fetch,
+        reservation,
+    )))
+}
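For intuition about what `SortPreservingMergeStream` computes through the cursor machinery above: a k-way merge of inputs that are already individually sorted. A self-contained toy version over plain vectors, illustrative only (none of these names exist in DataFusion):

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;

/// Toy k-way merge over pre-sorted vectors: always emit the smallest head.
fn kway_merge(inputs: &[Vec<i32>]) -> Vec<i32> {
    // Heap entries are (value, input index, offset); Reverse makes it a min-heap.
    let mut heap = BinaryHeap::new();
    for (i, input) in inputs.iter().enumerate() {
        if let Some(&first) = input.first() {
            heap.push(Reverse((first, i, 0usize)));
        }
    }
    let mut out = Vec::new();
    while let Some(Reverse((value, i, offset))) = heap.pop() {
        out.push(value);
        // Advance the cursor of the input the value came from.
        if let Some(&next) = inputs[i].get(offset + 1) {
            heap.push(Reverse((next, i, offset + 1)));
        }
    }
    out
}

fn main() {
    let merged = kway_merge(&[vec![1, 4, 7], vec![2, 5], vec![3, 6]]);
    assert_eq!(merged, vec![1, 2, 3, 4, 5, 6, 7]);
}
```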
diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
index dfef0ddefa03..800ea42b3562 100644
--- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
+++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
@@ -88,8 +88,6 @@ pub struct BoundedWindowAggExec {
     window_expr: Vec<Arc<dyn WindowExpr>>,
     /// Schema after the window is run
     schema: SchemaRef,
-    /// Schema before the window
-    input_schema: SchemaRef,
     /// Partition Keys
     pub partition_keys: Vec<Arc<dyn PhysicalExpr>>,
     /// Execution metrics
@@ -110,11 +108,10 @@ impl BoundedWindowAggExec {
     pub fn try_new(
         window_expr: Vec<Arc<dyn WindowExpr>>,
         input: Arc<dyn ExecutionPlan>,
-        input_schema: SchemaRef,
         partition_keys: Vec<Arc<dyn PhysicalExpr>>,
         partition_search_mode: PartitionSearchMode,
     ) -> Result<Self> {
-        let schema = create_schema(&input_schema, &window_expr)?;
+        let schema = create_schema(&input.schema(), &window_expr)?;
         let schema = Arc::new(schema);
         let partition_by_exprs = window_expr[0].partition_by();
         let ordered_partition_by_indices = match &partition_search_mode {
@@ -140,7 +137,6 @@ impl BoundedWindowAggExec {
             input,
             window_expr,
             schema,
-            input_schema,
             partition_keys,
             metrics: ExecutionPlanMetricsSet::new(),
             partition_search_mode,
@@ -158,11 +154,6 @@ impl BoundedWindowAggExec {
         &self.input
     }

-    /// Get the input schema before any window functions are applied
-    pub fn input_schema(&self) -> SchemaRef {
-        self.input_schema.clone()
-    }
-
     /// Return the output sort order of partition keys: For example
     /// OVER(PARTITION BY a, ORDER BY b) ->  would give sorting of the column a
     // We are sure that partition by columns are always at the beginning of sort_keys
@@ -303,7 +294,6 @@ impl ExecutionPlan for BoundedWindowAggExec {
         Ok(Arc::new(BoundedWindowAggExec::try_new(
             self.window_expr.clone(),
             children[0].clone(),
-            self.input_schema.clone(),
             self.partition_keys.clone(),
             self.partition_search_mode.clone(),
         )?))
@@ -333,7 +323,7 @@ impl ExecutionPlan for BoundedWindowAggExec {
     fn statistics(&self) -> Statistics {
         let input_stat = self.input.statistics();
         let win_cols = self.window_expr.len();
-        let input_cols = self.input_schema.fields().len();
+        let input_cols = self.input.schema().fields().len();
         // TODO stats: some windowing function will maintain invariants such as min, max...
         let mut column_statistics = Vec::with_capacity(win_cols + input_cols);
         if let Some(input_col_stats) = input_stat.column_statistics {
diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs
index 0f165f79354e..cc915e54af60 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -421,7 +421,6 @@ pub fn get_best_fitting_window(
         Ok(Some(Arc::new(BoundedWindowAggExec::try_new(
             window_expr,
             input.clone(),
-            input.schema(),
             physical_partition_keys.to_vec(),
             partition_search_mode,
         )?) as _))
@@ -435,7 +434,6 @@ pub fn get_best_fitting_window(
         Ok(Some(Arc::new(WindowAggExec::try_new(
             window_expr,
             input.clone(),
-            input.schema(),
             physical_partition_keys.to_vec(),
         )?) as _))
     }
@@ -759,7 +757,6 @@ mod tests {
                 schema.as_ref(),
             )?],
             blocking_exec,
-            schema,
             vec![],
         )?);
diff --git a/datafusion/physical-plan/src/windows/window_agg_exec.rs b/datafusion/physical-plan/src/windows/window_agg_exec.rs
index b56a9c194c8f..b4dc8ec88c68 100644
--- a/datafusion/physical-plan/src/windows/window_agg_exec.rs
+++ b/datafusion/physical-plan/src/windows/window_agg_exec.rs
@@ -59,8 +59,6 @@ pub struct WindowAggExec {
     window_expr: Vec<Arc<dyn WindowExpr>>,
     /// Schema after the window is run
     schema: SchemaRef,
-    /// Schema before the window
-    input_schema: SchemaRef,
     /// Partition Keys
     pub partition_keys: Vec<Arc<dyn PhysicalExpr>>,
     /// Execution metrics
@@ -75,10 +73,9 @@ impl WindowAggExec {
     pub fn try_new(
         window_expr: Vec<Arc<dyn WindowExpr>>,
         input: Arc<dyn ExecutionPlan>,
-        input_schema: SchemaRef,
         partition_keys: Vec<Arc<dyn PhysicalExpr>>,
     ) -> Result<Self> {
-        let schema = create_schema(&input_schema, &window_expr)?;
+        let schema = create_schema(&input.schema(), &window_expr)?;
         let schema = Arc::new(schema);

         let ordered_partition_by_indices =
@@ -87,7 +84,6 @@ impl WindowAggExec {
             input,
             window_expr,
             schema,
-            input_schema,
             partition_keys,
             metrics: ExecutionPlanMetricsSet::new(),
             ordered_partition_by_indices,
@@ -104,11 +100,6 @@ impl WindowAggExec {
         &self.input
     }

-    /// Get the input schema before any window functions are applied
-    pub fn input_schema(&self) -> SchemaRef {
-        self.input_schema.clone()
-    }
-
     /// Return the output sort order of partition keys: For example
     /// OVER(PARTITION BY a, ORDER BY b) ->  would give sorting of the column a
     // We are sure that partition by columns are always at the beginning of sort_keys
@@ -230,7 +221,6 @@ impl ExecutionPlan for WindowAggExec {
         Ok(Arc::new(WindowAggExec::try_new(
             self.window_expr.clone(),
             children[0].clone(),
-            self.input_schema.clone(),
             self.partition_keys.clone(),
         )?))
     }
@@ -259,7 +249,7 @@ impl ExecutionPlan for WindowAggExec {
     fn statistics(&self) -> Statistics {
         let input_stat = self.input.statistics();
         let win_cols = self.window_expr.len();
-        let input_cols = self.input_schema.fields().len();
+        let input_cols = self.input.schema().fields().len();
         // TODO stats: some windowing function will maintain invariants such as min, max...
let mut column_statistics = Vec::with_capacity(win_cols + input_cols); if let Some(input_col_stats) = input_stat.column_statistics { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index bda0f7828726..c60dae71ef86 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1419,7 +1419,6 @@ message PartiallySortedPartitionSearchMode { message WindowAggExecNode { PhysicalPlanNode input = 1; repeated PhysicalWindowExprNode window_expr = 2; - Schema input_schema = 4; repeated PhysicalExprNode partition_keys = 5; // Set optional to `None` for `BoundedWindowAggExec`. oneof partition_search_mode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index ced0c8bd7c7a..266075e68922 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -23969,9 +23969,6 @@ impl serde::Serialize for WindowAggExecNode { if !self.window_expr.is_empty() { len += 1; } - if self.input_schema.is_some() { - len += 1; - } if !self.partition_keys.is_empty() { len += 1; } @@ -23985,9 +23982,6 @@ impl serde::Serialize for WindowAggExecNode { if !self.window_expr.is_empty() { struct_ser.serialize_field("windowExpr", &self.window_expr)?; } - if let Some(v) = self.input_schema.as_ref() { - struct_ser.serialize_field("inputSchema", v)?; - } if !self.partition_keys.is_empty() { struct_ser.serialize_field("partitionKeys", &self.partition_keys)?; } @@ -24017,8 +24011,6 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { "input", "window_expr", "windowExpr", - "input_schema", - "inputSchema", "partition_keys", "partitionKeys", "linear", @@ -24031,7 +24023,6 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { enum GeneratedField { Input, WindowExpr, - InputSchema, PartitionKeys, Linear, PartiallySorted, @@ -24059,7 +24050,6 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { match value { "input" => Ok(GeneratedField::Input), "windowExpr" | "window_expr" => Ok(GeneratedField::WindowExpr), - "inputSchema" | "input_schema" => Ok(GeneratedField::InputSchema), "partitionKeys" | "partition_keys" => Ok(GeneratedField::PartitionKeys), "linear" => Ok(GeneratedField::Linear), "partiallySorted" | "partially_sorted" => Ok(GeneratedField::PartiallySorted), @@ -24085,7 +24075,6 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { { let mut input__ = None; let mut window_expr__ = None; - let mut input_schema__ = None; let mut partition_keys__ = None; let mut partition_search_mode__ = None; while let Some(k) = map_.next_key()? 
{
@@ -24102,12 +24091,6 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode {
                             }
                             window_expr__ = Some(map_.next_value()?);
                         }
-                        GeneratedField::InputSchema => {
-                            if input_schema__.is_some() {
-                                return Err(serde::de::Error::duplicate_field("inputSchema"));
-                            }
-                            input_schema__ = map_.next_value()?;
-                        }
                         GeneratedField::PartitionKeys => {
                             if partition_keys__.is_some() {
                                 return Err(serde::de::Error::duplicate_field("partitionKeys"));
                             }
@@ -24140,7 +24123,6 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode {
                 Ok(WindowAggExecNode {
                     input: input__,
                     window_expr: window_expr__.unwrap_or_default(),
-                    input_schema: input_schema__,
                     partition_keys: partition_keys__.unwrap_or_default(),
                     partition_search_mode: partition_search_mode__,
                 })
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index ca20cd35cb55..894afa570fb0 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1994,8 +1994,6 @@ pub struct WindowAggExecNode {
     pub input: ::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
     #[prost(message, repeated, tag = "2")]
     pub window_expr: ::prost::alloc::vec::Vec<PhysicalWindowExprNode>,
-    #[prost(message, optional, tag = "4")]
-    pub input_schema: ::core::option::Option<Schema>,
     #[prost(message, repeated, tag = "5")]
     pub partition_keys: ::prost::alloc::vec::Vec<PhysicalExprNode>,
     /// Set optional to `None` for `BoundedWindowAggExec`.
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 8257f9aa3458..08010a3151ee 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -282,16 +282,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     runtime,
                     extension_codec,
                 )?;
-                let input_schema = window_agg
-                    .input_schema
-                    .as_ref()
-                    .ok_or_else(|| {
-                        DataFusionError::Internal(
-                            "input_schema in WindowAggrNode is missing.".to_owned(),
-                        )
-                    })?
-                    .clone();
-                let input_schema: SchemaRef = SchemaRef::new((&input_schema).try_into()?);
+                let input_schema = input.schema();

                 let physical_window_expr: Vec<Arc<dyn WindowExpr>> = window_agg
                     .window_expr
@@ -333,7 +324,6 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     Ok(Arc::new(BoundedWindowAggExec::try_new(
                         physical_window_expr,
                         input,
-                        input_schema,
                         partition_keys,
                         partition_search_mode,
                     )?))
@@ -341,7 +331,6 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     Ok(Arc::new(WindowAggExec::try_new(
                         physical_window_expr,
                         input,
-                        input_schema,
                         partition_keys,
                     )?))
                 }
@@ -1315,8 +1304,6 @@ impl AsExecutionPlan for PhysicalPlanNode {
                 extension_codec,
             )?;

-            let input_schema = protobuf::Schema::try_from(exec.input_schema().as_ref())?;
-
             let window_expr = exec.window_expr()
                 .iter()
@@ -1334,7 +1321,6 @@ impl AsExecutionPlan for PhysicalPlanNode {
                 protobuf::WindowAggExecNode {
                     input: Some(Box::new(input)),
                     window_expr,
-                    input_schema: Some(input_schema),
                     partition_keys,
                     partition_search_mode: None,
                 },
@@ -1346,8 +1332,6 @@ impl AsExecutionPlan for PhysicalPlanNode {
                 extension_codec,
             )?;

-            let input_schema = protobuf::Schema::try_from(exec.input_schema().as_ref())?;
-
             let window_expr = exec.window_expr()
                 .iter()
@@ -1385,7 +1369,6 @@ impl AsExecutionPlan for PhysicalPlanNode {
                 protobuf::WindowAggExecNode {
                     input: Some(Box::new(input)),
                     window_expr,
-                    input_schema: Some(input_schema),
                     partition_keys,
                     partition_search_mode: Some(partition_search_mode),
                 },
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 77e77630bcb2..e30d416bdc95 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -275,7 +275,6 @@ fn roundtrip_window() -> Result<()> {
             sliding_aggr_window_expr,
         ],
         input,
-        schema.clone(),
         vec![col("b", &schema)?],
     )?))
 }
diff --git a/datafusion/sql/src/relation/join.rs b/datafusion/sql/src/relation/join.rs
index 0113f337e6dc..b119672eae5f 100644
--- a/datafusion/sql/src/relation/join.rs
+++ b/datafusion/sql/src/relation/join.rs
@@ -132,12 +132,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 // parse ON expression
                 let expr = self.sql_to_expr(sql_expr, &join_schema, planner_context)?;
                 LogicalPlanBuilder::from(left)
-                    .join(
-                        right,
-                        join_type,
-                        (Vec::<Column>::new(), Vec::<Column>::new()),
-                        Some(expr),
-                    )?
+                    .join_on(right, join_type, Some(expr))?
                     .build()
             }
             JoinConstraint::Using(idents) => {
diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml
index c7fabffa0aa0..454f99942f52 100644
--- a/datafusion/sqllogictest/Cargo.toml
+++ b/datafusion/sqllogictest/Cargo.toml
@@ -16,46 +16,46 @@
 # under the License.
[package] -authors.workspace = true -edition.workspace = true -homepage.workspace = true -license.workspace = true +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } name = "datafusion-sqllogictest" -readme.workspace = true -repository.workspace = true -rust-version.workspace = true -version.workspace = true +readme = { workspace = true } +repository = { workspace = true } +rust-version = { workspace = true } +version = { workspace = true } [lib] name = "datafusion_sqllogictest" path = "src/lib.rs" [dependencies] -arrow = {workspace = true} +arrow = { workspace = true } async-trait = "0.1.41" bigdecimal = "0.4.1" -datafusion = {path = "../core", version = "32.0.0"} -datafusion-common = {path = "../common", version = "32.0.0", default-features = false} +bytes = { version = "1.4.0", optional = true } +chrono = { workspace = true, optional = true } +datafusion = { path = "../core", version = "32.0.0" } +datafusion-common = { path = "../common", version = "32.0.0", default-features = false } +futures = { version = "0.3.28" } half = "2.2.1" itertools = "0.11" -object_store = "0.7.0" -rust_decimal = {version = "1.27.0"} log = "^0.4" +object_store = "0.7.0" +postgres-protocol = { version = "0.6.4", optional = true } +postgres-types = { version = "0.2.4", optional = true } +rust_decimal = { version = "1.27.0" } sqllogictest = "0.17.0" -sqlparser.workspace = true +sqlparser = { workspace = true } tempfile = "3" thiserror = "1.0.44" -tokio = {version = "1.0"} -bytes = {version = "1.4.0", optional = true} -futures = {version = "0.3.28"} -chrono = { workspace = true, optional = true } -tokio-postgres = {version = "0.7.7", optional = true} -postgres-types = {version = "0.2.4", optional = true} -postgres-protocol = {version = "0.6.4", optional = true} +tokio = { version = "1.0" } +tokio-postgres = { version = "0.7.7", optional = true } [features] -postgres = ["bytes", "chrono", "tokio-postgres", "postgres-types", "postgres-protocol"] avro = ["datafusion/avro"] +postgres = ["bytes", "chrono", "tokio-postgres", "postgres-types", "postgres-protocol"] [dev-dependencies] env_logger = "0.10" diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index b11a687d8b9f..cbb1896efb13 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -186,8 +186,7 @@ Bob_new John John_new -# should be un-nested -# https://github.com/apache/arrow-datafusion/issues/7786 +# should be un-nested, with a single (logical) aggregate query TT EXPLAIN SELECT name FROM t1 UNION (SELECT name from t2 UNION SELECT name || '_new' from t2) ---- @@ -195,26 +194,19 @@ logical_plan Aggregate: groupBy=[[t1.name]], aggr=[[]] --Union ----TableScan: t1 projection=[name] -----Aggregate: groupBy=[[t2.name]], aggr=[[]] -------Union ---------TableScan: t2 projection=[name] ---------Projection: t2.name || Utf8("_new") AS name -----------TableScan: t2 projection=[name] +----TableScan: t2 projection=[name] +----Projection: t2.name || Utf8("_new") AS name +------TableScan: t2 projection=[name] physical_plan AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[] --CoalesceBatchesExec: target_batch_size=8192 -----RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=8 +----RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=12 ------AggregateExec: mode=Partial, gby=[name@0 as name], aggr=[] --------UnionExec ----------MemoryExec: 
partitions=4, partition_sizes=[1, 0, 0, 0] -----------AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[] -------------CoalesceBatchesExec: target_batch_size=8192 ---------------RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=8 -----------------AggregateExec: mode=Partial, gby=[name@0 as name], aggr=[] -------------------UnionExec ---------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] ---------------------ProjectionExec: expr=[name@0 || _new as name] -----------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +----------ProjectionExec: expr=[name@0 || _new as name] +------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] # nested_union_all query T rowsort diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 5fb5a04c6709..80e496a336f4 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3245,17 +3245,17 @@ physical_plan ProjectionExec: expr=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum1, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum2, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum3, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum4] --BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Linear] ----CoalesceBatchesExec: target_batch_size=4096 -------SortPreservingRepartitionExec: partitioning=Hash([d@1], 2), input_partitions=2 +------SortPreservingRepartitionExec: partitioning=Hash([d@1], 2), input_partitions=2, sort_exprs=a@0 ASC NULLS LAST,b ASC NULLS LAST,c ASC NULLS LAST --------ProjectionExec: expr=[a@0 as a, d@3 as d, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as 
SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] ----------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] ------------CoalesceBatchesExec: target_batch_size=4096 ---------------SortPreservingRepartitionExec: partitioning=Hash([b@1, a@0], 2), input_partitions=2 +--------------SortPreservingRepartitionExec: partitioning=Hash([b@1, a@0], 2), input_partitions=2, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST ----------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[PartiallySorted([0])] ------------------CoalesceBatchesExec: target_batch_size=4096 ---------------------SortPreservingRepartitionExec: partitioning=Hash([a@0, d@3], 2), input_partitions=2 +--------------------SortPreservingRepartitionExec: partitioning=Hash([a@0, d@3], 2), input_partitions=2, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST ----------------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] ------------------------CoalesceBatchesExec: target_batch_size=4096 ---------------------------SortPreservingRepartitionExec: partitioning=Hash([a@0, b@1], 2), input_partitions=2 
+--------------------------SortPreservingRepartitionExec: partitioning=Hash([a@0, b@1], 2), input_partitions=2, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST
----------------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true
diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml
index 13b9c02cdd28..44a13cdcac37 100644
--- a/datafusion/substrait/Cargo.toml
+++ b/datafusion/substrait/Cargo.toml
@@ -33,9 +33,9 @@ chrono = { workspace = true }
 datafusion = { version = "32.0.0", path = "../core" }
 itertools = "0.11"
 object_store = "0.7.0"
-prost = "0.11"
-prost-types = "0.11"
-substrait = "0.16.0"
+prost = "0.12"
+prost-types = "0.12"
+substrait = "0.17.0"
 tokio = "1.17"

 [features]
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index e1dde39427a5..ae65a2c7d94a 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -18,6 +18,7 @@
 use async_recursion::async_recursion;
 use datafusion::arrow::datatypes::{DataType, Field, TimeUnit};
 use datafusion::common::{not_impl_err, DFField, DFSchema, DFSchemaRef};
+
 use datafusion::logical_expr::{
     aggregate_function, window_function::find_df_window_func, BinaryExpr,
     BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator,
@@ -129,6 +130,51 @@ fn scalar_function_type_from_str(name: &str) -> Result<ScalarFunctionType> {
     }
 }

+fn split_eq_and_noneq_join_predicate_with_nulls_equality(
+    filter: &Expr,
+) -> (Vec<(Column, Column)>, bool, Option<Expr>) {
+    let exprs = split_conjunction(filter);
+
+    let mut accum_join_keys: Vec<(Column, Column)> = vec![];
+    let mut accum_filters: Vec<Expr> = vec![];
+    let mut nulls_equal_nulls = false;
+
+    for expr in exprs {
+        match expr {
+            Expr::BinaryExpr(binary_expr) => match binary_expr {
+                x @ (BinaryExpr {
+                    left,
+                    op: Operator::Eq,
+                    right,
+                }
+                | BinaryExpr {
+                    left,
+                    op: Operator::IsNotDistinctFrom,
+                    right,
+                }) => {
+                    nulls_equal_nulls = match x.op {
+                        Operator::Eq => false,
+                        Operator::IsNotDistinctFrom => true,
+                        _ => unreachable!(),
+                    };
+
+                    match (left.as_ref(), right.as_ref()) {
+                        (Expr::Column(l), Expr::Column(r)) => {
+                            accum_join_keys.push((l.clone(), r.clone()));
+                        }
+                        _ => accum_filters.push(expr.clone()),
+                    }
+                }
+                _ => accum_filters.push(expr.clone()),
+            },
+            _ => accum_filters.push(expr.clone()),
+        }
+    }
+
+    let join_filter = accum_filters.into_iter().reduce(Expr::and);
+    (accum_join_keys, nulls_equal_nulls, join_filter)
+}
+
 /// Convert Substrait Plan to DataFusion DataFrame
 pub async fn from_substrait_plan(
     ctx: &mut SessionContext,
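Concretely: for `ON l.a = r.a AND l.c > r.c` the helper above yields join keys `[(l.a, r.a)]`, `nulls_equal_nulls = false`, and the residual filter `l.c > r.c`. A simplified, runnable sketch of the same splitting idea (a toy built on `datafusion_expr` builders, not the patch's helper):

```rust
use datafusion_expr::{col, BinaryExpr, Expr, Operator};

/// Toy splitter: walk AND-ed terms, pull out column-equality terms as join
/// keys, and keep everything else as a residual filter.
fn split(filter: &Expr, keys: &mut Vec<(String, String)>, rest: &mut Vec<Expr>) {
    match filter {
        Expr::BinaryExpr(BinaryExpr { left, op: Operator::And, right }) => {
            split(left, keys, rest);
            split(right, keys, rest);
        }
        Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) => {
            match (left.as_ref(), right.as_ref()) {
                (Expr::Column(l), Expr::Column(r)) => {
                    keys.push((l.flat_name(), r.flat_name()))
                }
                _ => rest.push(filter.clone()),
            }
        }
        other => rest.push(other.clone()),
    }
}

fn main() {
    let on = col("l.a").eq(col("r.a")).and(col("l.c").gt(col("r.c")));
    let (mut keys, mut rest) = (vec![], vec![]);
    split(&on, &mut keys, &mut rest);
    println!("join keys: {keys:?}"); // [("l.a", "r.a")]
    println!("residual:  {rest:?}"); // [l.c > r.c]
}
```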
@@ -336,7 +382,13 @@ pub async fn from_substrait_rel(
             }
         }
         Some(RelType::Join(join)) => {
-            let left = LogicalPlanBuilder::from(
+            if join.post_join_filter.is_some() {
+                return not_impl_err!(
+                    "JoinRel with post_join_filter is not yet supported"
+                );
+            }
+
+            let left: LogicalPlanBuilder = LogicalPlanBuilder::from(
                 from_substrait_rel(ctx, join.left.as_ref().unwrap(), extensions).await?,
             );
             let right = LogicalPlanBuilder::from(
@@ -346,65 +398,32 @@ pub async fn from_substrait_rel(
             // The join condition expression needs full input schema and not the output schema from join since we lose columns from
             // certain join types such as semi and anti joins
             let in_join_schema = left.schema().join(right.schema())?;
-            // Parse post join filter if exists
-            let join_filter = match &join.post_join_filter {
-                Some(filter) => {
-                    let parsed_filter =
-                        from_substrait_rex(filter, &in_join_schema, extensions).await?;
-                    Some(parsed_filter.as_ref().clone())
-                }
-                None => None,
-            };
+
             // If join expression exists, parse the `on` condition expression, build join and return
-            // Otherwise, build join with koin filter, without join keys
+            // Otherwise, build join with only the filter, without join keys
             match &join.expression.as_ref() {
                 Some(expr) => {
                     let on =
                         from_substrait_rex(expr, &in_join_schema, extensions).await?;
-                    let predicates = split_conjunction(&on);
-                    // TODO: collect only one null_eq_null
-                    let join_exprs: Vec<(Column, Column, bool)> = predicates
-                        .iter()
-                        .map(|p| match p {
-                            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
-                                match (left.as_ref(), right.as_ref()) {
-                                    (Expr::Column(l), Expr::Column(r)) => match op {
-                                        Operator::Eq => Ok((l.clone(), r.clone(), false)),
-                                        Operator::IsNotDistinctFrom => {
-                                            Ok((l.clone(), r.clone(), true))
-                                        }
-                                        _ => plan_err!("invalid join condition op"),
-                                    },
-                                    _ => plan_err!("invalid join condition expression"),
-                                }
-                            }
-                            _ => plan_err!(
-                                "Non-binary expression is not supported in join condition"
-                            ),
-                        })
-                        .collect::<Result<Vec<_>>>()?;
-                    let (left_cols, right_cols, null_eq_nulls): (Vec<_>, Vec<_>, Vec<_>) =
-                        itertools::multiunzip(join_exprs);
+                    // The join expression can contain both equal and non-equal ops.
+                    // As of datafusion 31.0.0, the equal and non-equal join conditions are in separate fields.
+                    // So we extract each part as follows:
+                    // - If an Eq or IsNotDistinctFrom op is encountered, add the left column, right column and is_null_equal_nulls to `join_ons` vector
+                    // - Otherwise we add the expression to join_filter (use conjunction if filter already exists)
+                    let (join_ons, nulls_equal_nulls, join_filter) =
+                        split_eq_and_noneq_join_predicate_with_nulls_equality(&on);
+                    let (left_cols, right_cols): (Vec<_>, Vec<_>) =
+                        itertools::multiunzip(join_ons);
                     left.join_detailed(
                         right.build()?,
                         join_type,
                         (left_cols, right_cols),
                         join_filter,
-                        null_eq_nulls[0],
+                        nulls_equal_nulls,
                     )?
                     .build()
                 }
-                None => match &join_filter {
-                    Some(_) => left
-                        .join(
-                            right.build()?,
-                            join_type,
-                            (Vec::<Column>::new(), Vec::<Column>::new()),
-                            join_filter,
-                        )?
-                        .build(),
-                    None => plan_err!("Join without join keys require a valid filter"),
-                },
+                None => plan_err!("JoinRel without join condition is not allowed"),
             }
         }
         Some(RelType::Read(read)) => match &read.as_ref().read_type {
@@ -461,8 +480,8 @@ pub async fn from_substrait_rel(
             }
             _ => not_impl_err!("Only NamedTable reads are supported"),
         },
-        Some(RelType::Set(set)) => match set_rel::SetOp::from_i32(set.op) {
-            Some(set_op) => match set_op {
+        Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) {
+            Ok(set_op) => match set_op {
                 set_rel::SetOp::UnionAll => {
                     if !set.inputs.is_empty() {
                         let mut union_builder = Ok(LogicalPlanBuilder::from(
@@ -479,7 +498,7 @@ pub async fn from_substrait_rel(
                 }
                 _ => not_impl_err!("Unsupported set operator: {set_op:?}"),
             },
-            None => not_impl_err!("Invalid set operation type None"),
+            Err(e) => not_impl_err!("Invalid set operation type {}: {e}", set.op),
         },
         Some(RelType::ExtensionLeaf(extension)) => {
             let Some(ext_detail) = &extension.detail else {
@@ -535,7 +554,7 @@ pub async fn from_substrait_rel(
 }

 fn from_substrait_jointype(join_type: i32) -> Result<JoinType> {
-    if let Some(substrait_join_type) = join_rel::JoinType::from_i32(join_type) {
+    if let Ok(substrait_join_type) = join_rel::JoinType::try_from(join_type) {
         match substrait_join_type {
             join_rel::JoinType::Inner => Ok(JoinType::Inner),
             join_rel::JoinType::Left => Ok(JoinType::Left),
@@ -563,7 +582,7 @@ pub async fn from_substrait_sorts(
         let asc_nullfirst = match &s.sort_kind {
             Some(k) => match k {
                 Direction(d) => {
-                    let Some(direction) = SortDirection::from_i32(*d) else {
+                    let Ok(direction) = SortDirection::try_from(*d) else {
                         return not_impl_err!(
                             "Unsupported Substrait SortDirection value {d}"
                         );
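Context for the mechanical `from_i32` to `try_from` churn above: with the bump to prost 0.12, generated protobuf enums are converted from raw `i32` values through a `TryFrom<i32>` impl rather than the old `Option`-returning `from_i32` helper, so the matches move from `Some`/`None` to `Ok`/`Err`. A minimal stand-alone illustration, using a hand-rolled enum as a stand-in for a prost-generated one:

```rust
/// Hand-rolled stand-in for a prost-generated enum; prost 0.12 derives an
/// equivalent TryFrom<i32> instead of an Option-returning from_i32 helper.
#[derive(Debug, Clone, Copy)]
enum SetOp {
    Unspecified = 0,
    UnionAll = 1,
}

impl TryFrom<i32> for SetOp {
    type Error = String;

    fn try_from(value: i32) -> Result<Self, Self::Error> {
        match value {
            0 => Ok(SetOp::Unspecified),
            1 => Ok(SetOp::UnionAll),
            other => Err(format!("unknown SetOp value {other}")),
        }
    }
}

fn main() {
    // The consumer's match now distinguishes Ok/Err instead of Some/None,
    // which also lets the error message carry the offending raw value.
    match SetOp::try_from(42) {
        Ok(op) => println!("set op: {op:?}"),
        Err(e) => println!("invalid set operation type 42: {e}"),
    }
}
```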
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 1124ea53a557..757bddf9fe58 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -278,14 +278,15 @@ pub fn to_substrait_rel(
             // parse filter if exists
             let in_join_schema = join.left.schema().join(join.right.schema())?;
             let join_filter = match &join.filter {
-                Some(filter) => Some(Box::new(to_substrait_rex(
+                Some(filter) => Some(to_substrait_rex(
                     filter,
                     &Arc::new(in_join_schema),
                     0,
                     extension_info,
-                )?)),
+                )?),
                 None => None,
             };
+
             // map the left and right columns to binary expressions in the form `l = r`
             // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b`
             let eq_op = if join.null_equals_null {
@@ -293,15 +294,31 @@ pub fn to_substrait_rel(
             } else {
                 Operator::Eq
             };
-
-            let join_expr = to_substrait_join_expr(
+            let join_on = to_substrait_join_expr(
                 &join.on,
                 eq_op,
                 join.left.schema(),
                 join.right.schema(),
                 extension_info,
-            )?
-            .map(Box::new);
+            )?;
+
+            // create a conjunction between `join_on` and `join_filter` to embed all join conditions,
+            // whether equal or non-equal, in a single expression
+            let join_expr = match &join_on {
+                Some(on_expr) => match &join_filter {
+                    Some(filter) => Some(Box::new(make_binary_op_scalar_func(
+                        on_expr,
+                        filter,
+                        Operator::And,
+                        extension_info,
+                    ))),
+                    None => join_on.map(Box::new), // the join expression will only contain `join_on` if the filter doesn't exist
+                },
+                None => match &join_filter {
+                    Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist
+                    None => None,
+                },
+            };

             Ok(Box::new(Rel {
                 rel_type: Some(RelType::Join(Box::new(JoinRel {
@@ -309,8 +326,8 @@ pub fn to_substrait_rel(
                     left: Some(left),
                     right: Some(right),
                     r#type: join_type as i32,
-                    expression: join_expr,
-                    post_join_filter: join_filter,
+                    expression: join_expr.clone(),
+                    post_join_filter: None,
                     advanced_extension: None,
                 }))),
             }))
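The net effect of the producer change above: equi-join keys and the residual filter are folded into the single `JoinRel::expression` with an AND, and `post_join_filter` is always `None`. A small sketch of the combining rule on DataFusion logical expressions (illustrative only; the real code combines Substrait expressions via `make_binary_op_scalar_func`):

```rust
use datafusion_expr::{col, Expr};

/// Combine optional equi-join and residual predicates into one expression,
/// mirroring how the producer now embeds both in JoinRel::expression.
fn combine(join_on: Option<Expr>, join_filter: Option<Expr>) -> Option<Expr> {
    match (join_on, join_filter) {
        (Some(on), Some(filter)) => Some(on.and(filter)),
        (Some(on), None) => Some(on),
        (None, filter) => filter,
    }
}

fn main() {
    let on = col("l.a").eq(col("r.a"));
    let filter = col("l.b").gt(col("r.b"));
    // With both present, the single join expression is `on AND filter`.
    let expr = combine(Some(on), Some(filter)).expect("non-empty");
    println!("{expr}"); // l.a = r.a AND l.b > r.b
}
```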
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 9b9afa159c20..32416125de24 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -23,15 +23,18 @@ use std::hash::Hash;
 use std::sync::Arc;

 use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
-use datafusion::common::{DFSchema, DFSchemaRef};
-use datafusion::error::Result;
+use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef};
+use datafusion::error::{DataFusionError, Result};
 use datafusion::execution::context::SessionState;
 use datafusion::execution::registry::SerializerRegistry;
 use datafusion::execution::runtime_env::RuntimeEnv;
 use datafusion::logical_expr::{Extension, LogicalPlan, UserDefinedLogicalNode};
 use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST;
 use datafusion::prelude::*;
+
 use substrait::proto::extensions::simple_extension_declaration::MappingType;
+use substrait::proto::rel::RelType;
+use substrait::proto::{plan_rel, Plan, Rel};

 struct MockSerializerRegistry;

@@ -383,12 +386,15 @@ async fn roundtrip_inner_join() -> Result<()> {

 #[tokio::test]
 async fn roundtrip_non_equi_inner_join() -> Result<()> {
-    roundtrip("SELECT data.a FROM data JOIN data2 ON data.a <> data2.a").await
+    roundtrip_verify_post_join_filter(
+        "SELECT data.a FROM data JOIN data2 ON data.a <> data2.a",
+    )
+    .await
 }

 #[tokio::test]
 async fn roundtrip_non_equi_join() -> Result<()> {
-    roundtrip(
+    roundtrip_verify_post_join_filter(
         "SELECT data.a FROM data, data2 WHERE data.a = data2.a AND data.e > data2.a",
     )
     .await
@@ -620,6 +626,91 @@ async fn extension_logical_plan() -> Result<()> {
     Ok(())
 }

+fn check_post_join_filters(rel: &Rel) -> Result<()> {
+    // search for the target rel and field value in the proto
+    match &rel.rel_type {
+        Some(RelType::Join(join)) => {
+            // check that the join filter is None
+            if join.post_join_filter.is_some() {
+                plan_err!(
+                    "DataFusion generated Substrait plan cannot have post_join_filter in JoinRel"
+                )
+            } else {
+                // recursively check JoinRels
+                match check_post_join_filters(join.left.as_ref().unwrap().as_ref()) {
+                    Err(e) => Err(e),
+                    Ok(_) => {
+                        check_post_join_filters(join.right.as_ref().unwrap().as_ref())
+                    }
+                }
+            }
+        }
+        Some(RelType::Project(p)) => {
+            check_post_join_filters(p.input.as_ref().unwrap().as_ref())
+        }
+        Some(RelType::Filter(filter)) => {
+            check_post_join_filters(filter.input.as_ref().unwrap().as_ref())
+        }
+        Some(RelType::Fetch(fetch)) => {
+            check_post_join_filters(fetch.input.as_ref().unwrap().as_ref())
+        }
+        Some(RelType::Sort(sort)) => {
+            check_post_join_filters(sort.input.as_ref().unwrap().as_ref())
+        }
+        Some(RelType::Aggregate(agg)) => {
+            check_post_join_filters(agg.input.as_ref().unwrap().as_ref())
+        }
+        Some(RelType::Set(set)) => {
+            for input in &set.inputs {
+                match check_post_join_filters(input) {
+                    Err(e) => return Err(e),
+                    Ok(_) => continue,
+                }
+            }
+            Ok(())
+        }
+        Some(RelType::ExtensionSingle(ext)) => {
+            check_post_join_filters(ext.input.as_ref().unwrap().as_ref())
+        }
+        Some(RelType::ExtensionMulti(ext)) => {
+            for input in &ext.inputs {
+                match check_post_join_filters(input) {
+                    Err(e) => return Err(e),
+                    Ok(_) => continue,
+                }
+            }
+            Ok(())
+        }
+        Some(RelType::ExtensionLeaf(_)) | Some(RelType::Read(_)) => Ok(()),
+        _ => not_impl_err!(
+            "Unsupported RelType: {:?} in post join filter check",
+            rel.rel_type
+        ),
+    }
+}
+
+async fn verify_post_join_filter_value(proto: Box<Plan>) -> Result<()> {
+    for relation in &proto.relations {
+        match relation.rel_type.as_ref() {
+            Some(rt) => match rt {
+                plan_rel::RelType::Rel(rel) => match check_post_join_filters(rel) {
+                    Err(e) => return Err(e),
+                    Ok(_) => continue,
+                },
+                plan_rel::RelType::Root(root) => {
+                    match check_post_join_filters(root.input.as_ref().unwrap()) {
+                        Err(e) => return Err(e),
+                        Ok(_) => continue,
+                    }
+                }
+            },
+            None => return plan_err!("Cannot parse plan relation: None"),
+        }
+    }
+
+    Ok(())
+}
+
 async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> {
     let mut ctx = create_context().await?;
     let df = ctx.sql(sql).await?;
@@ -688,6 +779,25 @@ async fn roundtrip(sql: &str) -> Result<()> {
     Ok(())
 }

+async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> {
+    let mut ctx = create_context().await?;
+    let df = ctx.sql(sql).await?;
+    let plan = df.into_optimized_plan()?;
+    let proto = to_substrait_plan(&plan, &ctx)?;
+    let plan2 = from_substrait_plan(&mut ctx, &proto).await?;
+    let plan2 = ctx.state().optimize(&plan2)?;
+
+    println!("{plan:#?}");
+    println!("{plan2:#?}");
+
+    let plan1str = format!("{plan:?}");
+    let plan2str = format!("{plan2:?}");
+    assert_eq!(plan1str, plan2str);
+
+    // verify that the post join filters are None
+    verify_post_join_filter_value(proto).await
+}
+
 async fn roundtrip_all_types(sql: &str) -> Result<()> {
     let mut ctx = create_all_type_context().await?;
     let df = ctx.sql(sql).await?;