diff --git a/.gitattributes b/.gitattributes index bcdeffc09a11..84b47a6fc56e 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,3 +1,4 @@ .github/ export-ignore +datafusion/core/tests/data/newlines_in_values.csv text eol=lf datafusion/proto/src/generated/prost.rs linguist-generated datafusion/proto/src/generated/pbjson.rs linguist-generated diff --git a/.github/workflows/large_files.yml b/.github/workflows/large_files.yml new file mode 100644 index 000000000000..aa96d55a0d85 --- /dev/null +++ b/.github/workflows/large_files.yml @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +name: Large files PR check + +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + +on: + pull_request: + +jobs: + check-files: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Check size of new Git objects + env: + # 1 MB ought to be enough for anybody. + # TODO in case we may want to consciously commit a bigger file to the repo without using Git LFS we may disable the check e.g. with a label + MAX_FILE_SIZE_BYTES: 1048576 + shell: bash + run: | + git rev-list --objects ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }} \ + > pull-request-objects.txt + exit_code=0 + while read -r id path; do + # Skip objects which are not files (commits, trees) + if [ ! -z "${path}" ]; then + size="$(git cat-file -s "${id}")" + if [ "${size}" -gt "${MAX_FILE_SIZE_BYTES}" ]; then + exit_code=1 + echo "Object ${id} [${path}] has size ${size}, exceeding ${MAX_FILE_SIZE_BYTES} limit." >&2 + echo "::error file=${path}::File ${path} has size ${size}, exceeding ${MAX_FILE_SIZE_BYTES} limit." 
+ fi + fi + done < pull-request-objects.txt + exit "${exit_code}" diff --git a/Cargo.toml b/Cargo.toml index f61ed7e58fe3..24bde78b3001 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ members = [ "datafusion/optimizer", "datafusion/physical-expr-common", "datafusion/physical-expr", + "datafusion/physical-optimizer", "datafusion/physical-plan", "datafusion/proto", "datafusion/proto/gen", @@ -97,6 +98,7 @@ datafusion-functions-array = { path = "datafusion/functions-array", version = "4 datafusion-optimizer = { path = "datafusion/optimizer", version = "40.0.0", default-features = false } datafusion-physical-expr = { path = "datafusion/physical-expr", version = "40.0.0", default-features = false } datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "40.0.0", default-features = false } +datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "40.0.0" } datafusion-physical-plan = { path = "datafusion/physical-plan", version = "40.0.0" } datafusion-proto = { path = "datafusion/proto", version = "40.0.0" } datafusion-proto-common = { path = "datafusion/proto-common", version = "40.0.0" } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index e48c6b081e1a..84bff8c87190 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -118,9 +118,9 @@ dependencies = [ [[package]] name = "arrayref" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b4930d2cb77ce62f89ee5d5289b4ac049559b1c45539271f5ed4fdc7db34545" +checksum = "9d151e35f61089500b617991b791fc8bfd237ae50cd5950803758a179b41e67a" [[package]] name = "arrayvec" @@ -387,7 +387,7 @@ checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -772,9 +772,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.5.1" +version = "1.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30cca6d3674597c30ddf2c587bf8d9d65c9a84d2326d941cc79c9842dfe0ef52" +checksum = "e9ec96fe9a81b5e365f9db71fe00edc4fe4ca2cc7dcb7861f0603012a7caa210" dependencies = [ "arrayref", "arrayvec", @@ -838,9 +838,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.6.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" +checksum = "a12916984aab3fa6e39d655a33e09c0071eb36d6ab3aea5c2d78551f1df6d952" [[package]] name = "bytes-utils" @@ -875,13 +875,12 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.0" +version = "1.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eaff6f8ce506b9773fa786672d63fc7a191ffea1be33f72bbd4aeacefca9ffc8" +checksum = "2aba8f4e9906c7ce3c73463f62a7f0c65183ada1a2d47e397cc8810827f9694f" dependencies = [ "jobserver", "libc", - "once_cell", ] [[package]] @@ -1105,7 +1104,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edb49164822f3ee45b17acd4a208cfc1251410cf0cad9a833234c9890774dd9f" dependencies = [ "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -1154,6 +1153,7 @@ dependencies = [ "datafusion-optimizer", "datafusion-physical-expr", "datafusion-physical-expr-common", + "datafusion-physical-optimizer", "datafusion-physical-plan", "datafusion-sql", 
"flate2", @@ -1331,6 +1331,7 @@ dependencies = [ "itertools", "log", "paste", + "rand", ] [[package]] @@ -1391,6 +1392,16 @@ dependencies = [ "rand", ] +[[package]] +name = "datafusion-physical-optimizer" +version = "40.0.0" +dependencies = [ + "datafusion-common", + "datafusion-execution", + "datafusion-physical-expr", + "datafusion-physical-plan", +] + [[package]] name = "datafusion-physical-plan" version = "40.0.0" @@ -1694,7 +1705,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -1908,9 +1919,9 @@ dependencies = [ [[package]] name = "http-body" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", "http 1.1.0", @@ -1925,7 +1936,7 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "pin-project-lite", ] @@ -1982,7 +1993,7 @@ dependencies = [ "futures-util", "h2 0.4.5", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "httparse", "itoa", "pin-project-lite", @@ -2034,7 +2045,7 @@ dependencies = [ "futures-channel", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "hyper 1.4.1", "pin-project-lite", "socket2", @@ -2707,7 +2718,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -2917,9 +2928,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd" +checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4" dependencies = [ "bitflags 2.6.0", ] @@ -2982,7 +2993,7 @@ dependencies = [ "futures-util", "h2 0.4.5", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", "hyper 1.4.1", "hyper-rustls 0.27.2", @@ -3269,9 +3280,9 @@ dependencies = [ [[package]] name = "security-framework" -version = "2.11.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.6.0", "core-foundation", @@ -3282,9 +3293,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.11.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7" +checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf" dependencies = [ "core-foundation-sys", "libc", @@ -3319,7 +3330,7 @@ checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -3454,7 +3465,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -3500,7 +3511,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -3513,7 +3524,7 @@ dependencies 
= [ "proc-macro2", "quote", "rustversion", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -3535,9 +3546,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.70" +version = "2.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f0209b68b3613b093e0ec905354eccaedcfe83b8cb37cbdeae64026c3064c16" +checksum = "b146dcf730474b4bcd16c311627b31ede9ab149045db4d6088b3becaea046462" dependencies = [ "proc-macro2", "quote", @@ -3585,22 +3596,22 @@ checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9" [[package]] name = "thiserror" -version = "1.0.61" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" +checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.61" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" +checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -3670,9 +3681,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.38.0" +version = "1.38.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a" +checksum = "eb2caba9f80616f438e09748d5acda951967e1ea58508ef53d9c6402485a46df" dependencies = [ "backtrace", "bytes", @@ -3695,7 +3706,7 @@ checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -3792,7 +3803,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -3837,7 +3848,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -3991,7 +4002,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", "wasm-bindgen-shared", ] @@ -4025,7 +4036,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4290,7 +4301,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_file_format.rs index bdb702375c94..8612a1cc4430 100644 --- a/datafusion-examples/examples/custom_file_format.rs +++ b/datafusion-examples/examples/custom_file_format.rs @@ -131,7 +131,7 @@ impl FileFormat for TSVFileFormat { } } -#[derive(Default)] +#[derive(Default, Debug)] /// Factory for creating TSV file formats /// /// This factory is a wrapper around the CSV file format factory @@ -166,6 +166,10 @@ impl FileFormatFactory for TSVFileFactory { fn default(&self) -> std::sync::Arc { todo!() } + + fn as_any(&self) -> &dyn Any { + self + } } impl GetExt for TSVFileFactory { diff --git 
a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index a5cf7011f811..a48171c625a8 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -177,16 +177,12 @@ fn simplify_demo() -> Result<()> { ); // here are some other examples of what DataFusion is capable of - let schema = Schema::new(vec![ - make_field("i", DataType::Int64), - make_field("b", DataType::Boolean), - ]) - .to_dfschema_ref()?; + let schema = Schema::new(vec![make_field("i", DataType::Int64)]).to_dfschema_ref()?; let context = SimplifyContext::new(&props).with_schema(schema.clone()); let simplifier = ExprSimplifier::new(context); // basic arithmetic simplification - // i + 1 + 2 => a + 3 + // i + 1 + 2 => i + 3 // (note this is not done if the expr is (col("i") + (lit(1) + lit(2)))) assert_eq!( simplifier.simplify(col("i") + (lit(1) + lit(2)))?, @@ -209,7 +205,7 @@ fn simplify_demo() -> Result<()> { ); // String --> Date simplification - // `cast('2020-09-01' as date)` --> 18500 + // `cast('2020-09-01' as date)` --> 18506 # number of days since epoch 1970-01-01 assert_eq!( simplifier.simplify(lit("2020-09-01").cast_to(&DataType::Date32, &schema)?)?, lit(ScalarValue::Date32(Some(18506))) diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index e36a4f890644..2e2bfff40340 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -214,7 +214,7 @@ impl Column { for using_col in using_columns { let all_matched = columns.iter().all(|f| using_col.contains(f)); // All matched fields belong to the same using column set, in orther words - // the same join clause. We simply pick the qualifer from the first match. + // the same join clause. We simply pick the qualifier from the first match. if all_matched { return Ok(columns[0].clone()); } @@ -303,7 +303,7 @@ impl Column { for using_col in using_columns { let all_matched = columns.iter().all(|c| using_col.contains(c)); // All matched fields belong to the same using column set, in orther words - // the same join clause. We simply pick the qualifer from the first match. + // the same join clause. We simply pick the qualifier from the first match. if all_matched { return Ok(columns[0].clone()); } diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 1c46e103abcd..7f7c377e8386 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -184,6 +184,16 @@ config_namespace! { /// Default value for `format.has_header` for `CREATE EXTERNAL TABLE` /// if not specified explicitly in the statement. pub has_header: bool, default = false + + /// Specifies whether newlines in (quoted) CSV values are supported. + /// + /// This is the default value for `format.newlines_in_values` for `CREATE EXTERNAL TABLE` + /// if not specified explicitly in the statement. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + pub newlines_in_values: bool, default = false } } @@ -315,93 +325,99 @@ config_namespace! { } config_namespace! 
{ - /// Options related to parquet files + /// Options for reading and writing parquet files /// /// See also: [`SessionConfig`] /// /// [`SessionConfig`]: https://docs.rs/datafusion/latest/datafusion/prelude/struct.SessionConfig.html pub struct ParquetOptions { - /// If true, reads the Parquet data page level metadata (the + // The following options affect reading parquet files + + /// (reading) If true, reads the Parquet data page level metadata (the /// Page Index), if present, to reduce the I/O and number of /// rows decoded. pub enable_page_index: bool, default = true - /// If true, the parquet reader attempts to skip entire row groups based + /// (reading) If true, the parquet reader attempts to skip entire row groups based /// on the predicate in the query and the metadata (min/max values) stored in /// the parquet file pub pruning: bool, default = true - /// If true, the parquet reader skip the optional embedded metadata that may be in + /// (reading) If true, the parquet reader skip the optional embedded metadata that may be in /// the file Schema. This setting can help avoid schema conflicts when querying /// multiple parquet files with schemas containing compatible types but different metadata pub skip_metadata: bool, default = true - /// If specified, the parquet reader will try and fetch the last `size_hint` + /// (reading) If specified, the parquet reader will try and fetch the last `size_hint` /// bytes of the parquet file optimistically. If not specified, two reads are required: /// One read to fetch the 8-byte parquet footer and /// another to fetch the metadata length encoded in the footer pub metadata_size_hint: Option, default = None - /// If true, filter expressions are be applied during the parquet decoding operation to + /// (reading) If true, filter expressions are be applied during the parquet decoding operation to /// reduce the number of rows decoded. This optimization is sometimes called "late materialization". pub pushdown_filters: bool, default = false - /// If true, filter expressions evaluated during the parquet decoding operation + /// (reading) If true, filter expressions evaluated during the parquet decoding operation /// will be reordered heuristically to minimize the cost of evaluation. If false, /// the filters are applied in the same order as written in the query pub reorder_filters: bool, default = false - // The following map to parquet::file::properties::WriterProperties + // The following options affect writing to parquet files + // and map to parquet::file::properties::WriterProperties - /// Sets best effort maximum size of data page in bytes + /// (writing) Sets best effort maximum size of data page in bytes pub data_pagesize_limit: usize, default = 1024 * 1024 - /// Sets write_batch_size in bytes + /// (writing) Sets write_batch_size in bytes pub write_batch_size: usize, default = 1024 - /// Sets parquet writer version + /// (writing) Sets parquet writer version /// valid values are "1.0" and "2.0" - pub writer_version: String, default = "1.0".into() + pub writer_version: String, default = "1.0".to_string() - /// Sets default parquet compression codec + /// (writing) Sets default parquet compression codec. /// Valid values are: uncompressed, snappy, gzip(level), /// lzo, brotli(level), lz4, zstd(level), and lz4_raw. /// These values are not case sensitive. If NULL, uses /// default parquet writer setting + /// + /// Note that this default setting is not the same as + /// the default parquet writer setting. 
pub compression: Option, default = Some("zstd(3)".into()) - /// Sets if dictionary encoding is enabled. If NULL, uses + /// (writing) Sets if dictionary encoding is enabled. If NULL, uses /// default parquet writer setting - pub dictionary_enabled: Option, default = None + pub dictionary_enabled: Option, default = Some(true) - /// Sets best effort maximum dictionary page size, in bytes + /// (writing) Sets best effort maximum dictionary page size, in bytes pub dictionary_page_size_limit: usize, default = 1024 * 1024 - /// Sets if statistics are enabled for any column + /// (writing) Sets if statistics are enabled for any column /// Valid values are: "none", "chunk", and "page" /// These values are not case sensitive. If NULL, uses /// default parquet writer setting pub statistics_enabled: Option, default = None - /// Sets max statistics size for any column. If NULL, uses + /// (writing) Sets max statistics size for any column. If NULL, uses /// default parquet writer setting - pub max_statistics_size: Option, default = None + pub max_statistics_size: Option, default = Some(4096) - /// Target maximum number of rows in each row group (defaults to 1M + /// (writing) Target maximum number of rows in each row group (defaults to 1M /// rows). Writing larger row groups requires more memory to write, but /// can get better compression and be faster to read. - pub max_row_group_size: usize, default = 1024 * 1024 + pub max_row_group_size: usize, default = 1024 * 1024 - /// Sets "created by" property + /// (writing) Sets "created by" property pub created_by: String, default = concat!("datafusion version ", env!("CARGO_PKG_VERSION")).into() - /// Sets column index truncate length - pub column_index_truncate_length: Option, default = None + /// (writing) Sets column index truncate length + pub column_index_truncate_length: Option, default = Some(64) - /// Sets best effort maximum number of rows in data page - pub data_page_row_count_limit: usize, default = usize::MAX + /// (writing) Sets best effort maximum number of rows in data page + pub data_page_row_count_limit: usize, default = 20_000 - /// Sets default encoding for any column + /// (writing) Sets default encoding for any column. /// Valid values are: plain, plain_dictionary, rle, /// bit_packed, delta_binary_packed, delta_length_byte_array, /// delta_byte_array, rle_dictionary, and byte_stream_split. @@ -409,27 +425,27 @@ config_namespace! { /// default parquet writer setting pub encoding: Option, default = None - /// Use any available bloom filters when reading parquet files + /// (writing) Use any available bloom filters when reading parquet files pub bloom_filter_on_read: bool, default = true - /// Write bloom filters for all columns when creating parquet files + /// (writing) Write bloom filters for all columns when creating parquet files pub bloom_filter_on_write: bool, default = false - /// Sets bloom filter false positive probability. If NULL, uses + /// (writing) Sets bloom filter false positive probability. If NULL, uses /// default parquet writer setting pub bloom_filter_fpp: Option, default = None - /// Sets bloom filter number of distinct values. If NULL, uses + /// (writing) Sets bloom filter number of distinct values. If NULL, uses /// default parquet writer setting pub bloom_filter_ndv: Option, default = None - /// Controls whether DataFusion will attempt to speed up writing + /// (writing) Controls whether DataFusion will attempt to speed up writing /// parquet files by serializing them in parallel. 
Each column /// in each row group in each output file are serialized in parallel /// leveraging a maximum possible core count of n_files*n_row_groups*n_columns. pub allow_single_file_parallelism: bool, default = true - /// By default parallel parquet writer is tuned for minimum + /// (writing) By default parallel parquet writer is tuned for minimum /// memory usage in a streaming execution plan. You may see /// a performance benefit when writing large parquet files /// by increasing maximum_parallel_row_group_writers and @@ -440,7 +456,7 @@ config_namespace! { /// data frame. pub maximum_parallel_row_group_writers: usize, default = 1 - /// By default parallel parquet writer is tuned for minimum + /// (writing) By default parallel parquet writer is tuned for minimum /// memory usage in a streaming execution plan. You may see /// a performance benefit when writing large parquet files /// by increasing maximum_parallel_row_group_writers and @@ -450,7 +466,6 @@ config_namespace! { /// writing out already in-memory data, such as from a cached /// data frame. pub maximum_buffered_record_batches_per_stream: usize, default = 2 - } } @@ -1534,6 +1549,9 @@ macro_rules! config_namespace_with_hashmap { } config_namespace_with_hashmap! { + /// Options controlling parquet format for individual columns. + /// + /// See [`ParquetOptions`] for more details pub struct ParquetColumnOptions { /// Sets if bloom filter is enabled for the column path. pub bloom_filter_enabled: Option, default = None @@ -1588,6 +1606,14 @@ config_namespace! { pub quote: u8, default = b'"' pub escape: Option, default = None pub double_quote: Option, default = None + /// Specifies whether newlines in (quoted) values are supported. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + /// + /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. + pub newlines_in_values: Option, default = None pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED pub schema_infer_max_rec: usize, default = 100 pub date_format: Option, default = None @@ -1660,6 +1686,18 @@ impl CsvOptions { self } + /// Specifies whether newlines in (quoted) values are supported. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + /// + /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. 
+ pub fn with_newlines_in_values(mut self, newlines_in_values: bool) -> Self { + self.newlines_in_values = Some(newlines_in_values); + self + } + /// Set a `CompressionTypeVariant` of CSV /// - defaults to `CompressionTypeVariant::UNCOMPRESSED` pub fn with_file_compression_type( diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 7598cbc4d86a..f0eecd2ffeb1 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -521,34 +521,8 @@ impl DFSchema { /// Find the field with the given name pub fn field_with_unqualified_name(&self, name: &str) -> Result<&Field> { - let matches = self.qualified_fields_with_unqualified_name(name); - match matches.len() { - 0 => Err(unqualified_field_not_found(name, self)), - 1 => Ok(matches[0].1), - _ => { - // When `matches` size > 1, it doesn't necessarily mean an `ambiguous name` problem. - // Because name may generate from Alias/... . It means that it don't own qualifier. - // For example: - // Join on id = b.id - // Project a.id as id TableScan b id - // In this case, there isn't `ambiguous name` problem. When `matches` just contains - // one field without qualifier, we should return it. - let fields_without_qualifier = matches - .iter() - .filter(|(q, _)| q.is_none()) - .collect::>(); - if fields_without_qualifier.len() == 1 { - Ok(fields_without_qualifier[0].1) - } else { - _schema_err!(SchemaError::AmbiguousReference { - field: Column { - relation: None, - name: name.to_string(), - }, - }) - } - } - } + self.qualified_field_with_unqualified_name(name) + .map(|(_, field)| field) } /// Find the field with the given qualified name @@ -1310,7 +1284,7 @@ mod tests { ]) } #[test] - fn test_dfschema_to_schema_convertion() { + fn test_dfschema_to_schema_conversion() { let mut a_metadata = HashMap::new(); a_metadata.insert("key".to_string(), "value".to_string()); let a_field = Field::new("a", DataType::Int64, false).with_metadata(a_metadata); diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 9be662ca283e..58ff1121e36d 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -661,11 +661,16 @@ mod test { assert_eq!(res.strip_backtrace(), "Arrow error: Schema error: bar"); } - // RUST_BACKTRACE=1 cargo test --features backtrace --package datafusion-common --lib -- error::test::test_backtrace + // To pass the test the environment variable RUST_BACKTRACE should be set to 1 to enforce backtrace #[cfg(feature = "backtrace")] #[test] #[allow(clippy::unnecessary_literal_unwrap)] fn test_enabled_backtrace() { + match std::env::var("RUST_BACKTRACE") { + Ok(val) if val == "1" => {} + _ => panic!("Environment variable RUST_BACKTRACE must be set to 1"), + }; + let res: Result<(), DataFusionError> = plan_err!("Err"); let err = res.unwrap_err().to_string(); assert!(err.contains(DataFusionError::BACK_TRACE_SEP)); diff --git a/datafusion/common/src/file_options/parquet_writer.rs b/datafusion/common/src/file_options/parquet_writer.rs index 009164a29e34..21d49f701e75 100644 --- a/datafusion/common/src/file_options/parquet_writer.rs +++ b/datafusion/common/src/file_options/parquet_writer.rs @@ -24,16 +24,18 @@ use crate::{ use parquet::{ basic::{BrotliLevel, GzipLevel, ZstdLevel}, - file::{ - metadata::KeyValue, - properties::{EnabledStatistics, WriterProperties, WriterVersion}, + file::properties::{ + EnabledStatistics, WriterProperties, WriterPropertiesBuilder, WriterVersion, + DEFAULT_MAX_STATISTICS_SIZE, DEFAULT_STATISTICS_ENABLED, }, + format::KeyValue, 
schema::types::ColumnPath, }; /// Options for writing parquet files #[derive(Clone, Debug)] pub struct ParquetWriterOptions { + /// parquet-rs writer properties pub writer_options: WriterProperties, } @@ -52,92 +54,43 @@ impl ParquetWriterOptions { impl TryFrom<&TableParquetOptions> for ParquetWriterOptions { type Error = DataFusionError; - fn try_from(parquet_options: &TableParquetOptions) -> Result { - let ParquetOptions { - data_pagesize_limit, - write_batch_size, - writer_version, - dictionary_page_size_limit, - max_row_group_size, - created_by, - column_index_truncate_length, - data_page_row_count_limit, - bloom_filter_on_write, - encoding, - dictionary_enabled, - compression, - statistics_enabled, - max_statistics_size, - bloom_filter_fpp, - bloom_filter_ndv, - // below is not part of ParquetWriterOptions - enable_page_index: _, - pruning: _, - skip_metadata: _, - metadata_size_hint: _, - pushdown_filters: _, - reorder_filters: _, - allow_single_file_parallelism: _, - maximum_parallel_row_group_writers: _, - maximum_buffered_record_batches_per_stream: _, - bloom_filter_on_read: _, - } = &parquet_options.global; - - let key_value_metadata = if !parquet_options.key_value_metadata.is_empty() { - Some( - parquet_options - .key_value_metadata - .clone() - .drain() - .map(|(key, value)| KeyValue { key, value }) - .collect::>(), - ) - } else { - None - }; - - let mut builder = WriterProperties::builder() - .set_data_page_size_limit(*data_pagesize_limit) - .set_write_batch_size(*write_batch_size) - .set_writer_version(parse_version_string(writer_version.as_str())?) - .set_dictionary_page_size_limit(*dictionary_page_size_limit) - .set_max_row_group_size(*max_row_group_size) - .set_created_by(created_by.clone()) - .set_column_index_truncate_length(*column_index_truncate_length) - .set_data_page_row_count_limit(*data_page_row_count_limit) - .set_bloom_filter_enabled(*bloom_filter_on_write) - .set_key_value_metadata(key_value_metadata); - - if let Some(encoding) = &encoding { - builder = builder.set_encoding(parse_encoding_string(encoding)?); - } - - if let Some(enabled) = dictionary_enabled { - builder = builder.set_dictionary_enabled(*enabled); - } - - if let Some(compression) = &compression { - builder = builder.set_compression(parse_compression_string(compression)?); - } - - if let Some(statistics) = &statistics_enabled { - builder = - builder.set_statistics_enabled(parse_statistics_string(statistics)?); - } - - if let Some(size) = max_statistics_size { - builder = builder.set_max_statistics_size(*size); - } + fn try_from(parquet_table_options: &TableParquetOptions) -> Result { + // ParquetWriterOptions will have defaults for the remaining fields (e.g. sorting_columns) + Ok(ParquetWriterOptions { + writer_options: WriterPropertiesBuilder::try_from(parquet_table_options)? + .build(), + }) + } +} - if let Some(fpp) = bloom_filter_fpp { - builder = builder.set_bloom_filter_fpp(*fpp); - } +impl TryFrom<&TableParquetOptions> for WriterPropertiesBuilder { + type Error = DataFusionError; - if let Some(ndv) = bloom_filter_ndv { - builder = builder.set_bloom_filter_ndv(*ndv); + /// Convert the session's [`TableParquetOptions`] into a single write action's [`WriterPropertiesBuilder`]. + /// + /// The returned [`WriterPropertiesBuilder`] includes customizations applicable per column. 
+ fn try_from(table_parquet_options: &TableParquetOptions) -> Result { + // Table options include kv_metadata and col-specific options + let TableParquetOptions { + global, + column_specific_options, + key_value_metadata, + } = table_parquet_options; + + let mut builder = global.into_writer_properties_builder()?; + + if !key_value_metadata.is_empty() { + builder = builder.set_key_value_metadata(Some( + key_value_metadata + .to_owned() + .drain() + .map(|(key, value)| KeyValue { key, value }) + .collect(), + )); } - for (column, options) in &parquet_options.column_specific_options { + // Apply column-specific options: + for (column, options) in column_specific_options { let path = ColumnPath::new(column.split('.').map(|s| s.to_owned()).collect()); if let Some(bloom_filter_enabled) = options.bloom_filter_enabled { @@ -183,10 +136,87 @@ impl TryFrom<&TableParquetOptions> for ParquetWriterOptions { } } - // ParquetWriterOptions will have defaults for the remaining fields (e.g. sorting_columns) - Ok(ParquetWriterOptions { - writer_options: builder.build(), - }) + Ok(builder) + } +} + +impl ParquetOptions { + /// Convert the global session options, [`ParquetOptions`], into a single write action's [`WriterPropertiesBuilder`]. + /// + /// The returned [`WriterPropertiesBuilder`] can then be further modified with additional options + /// applied per column; a customization which is not applicable for [`ParquetOptions`]. + pub fn into_writer_properties_builder(&self) -> Result { + let ParquetOptions { + data_pagesize_limit, + write_batch_size, + writer_version, + compression, + dictionary_enabled, + dictionary_page_size_limit, + statistics_enabled, + max_statistics_size, + max_row_group_size, + created_by, + column_index_truncate_length, + data_page_row_count_limit, + encoding, + bloom_filter_on_write, + bloom_filter_fpp, + bloom_filter_ndv, + + // not in WriterProperties + enable_page_index: _, + pruning: _, + skip_metadata: _, + metadata_size_hint: _, + pushdown_filters: _, + reorder_filters: _, + allow_single_file_parallelism: _, + maximum_parallel_row_group_writers: _, + maximum_buffered_record_batches_per_stream: _, + bloom_filter_on_read: _, // reads not used for writer props + } = self; + + let mut builder = WriterProperties::builder() + .set_data_page_size_limit(*data_pagesize_limit) + .set_write_batch_size(*write_batch_size) + .set_writer_version(parse_version_string(writer_version.as_str())?) + .set_dictionary_page_size_limit(*dictionary_page_size_limit) + .set_statistics_enabled( + statistics_enabled + .as_ref() + .and_then(|s| parse_statistics_string(s).ok()) + .unwrap_or(DEFAULT_STATISTICS_ENABLED), + ) + .set_max_statistics_size( + max_statistics_size.unwrap_or(DEFAULT_MAX_STATISTICS_SIZE), + ) + .set_max_row_group_size(*max_row_group_size) + .set_created_by(created_by.clone()) + .set_column_index_truncate_length(*column_index_truncate_length) + .set_data_page_row_count_limit(*data_page_row_count_limit) + .set_bloom_filter_enabled(*bloom_filter_on_write); + + if let Some(bloom_filter_fpp) = bloom_filter_fpp { + builder = builder.set_bloom_filter_fpp(*bloom_filter_fpp); + }; + if let Some(bloom_filter_ndv) = bloom_filter_ndv { + builder = builder.set_bloom_filter_ndv(*bloom_filter_ndv); + }; + if let Some(dictionary_enabled) = dictionary_enabled { + builder = builder.set_dictionary_enabled(*dictionary_enabled); + }; + + // We do not have access to default ColumnProperties set in Arrow. + // Therefore, only overwrite if these settings exist. 
+ if let Some(compression) = compression { + builder = builder.set_compression(parse_compression_string(compression)?); + } + if let Some(encoding) = encoding { + builder = builder.set_encoding(parse_encoding_string(encoding)?); + } + + Ok(builder) } } @@ -336,3 +366,426 @@ pub(crate) fn parse_statistics_string(str_setting: &str) -> Result ParquetColumnOptions { + ParquetColumnOptions { + compression: Some("zstd(22)".into()), + dictionary_enabled: src_col_defaults.dictionary_enabled.map(|v| !v), + statistics_enabled: Some("page".into()), + max_statistics_size: Some(72), + encoding: Some("RLE".into()), + bloom_filter_enabled: Some(true), + bloom_filter_fpp: Some(0.72), + bloom_filter_ndv: Some(72), + } + } + + fn parquet_options_with_non_defaults() -> ParquetOptions { + let defaults = ParquetOptions::default(); + let writer_version = if defaults.writer_version.eq("1.0") { + "2.0" + } else { + "1.0" + }; + + ParquetOptions { + data_pagesize_limit: 42, + write_batch_size: 42, + writer_version: writer_version.into(), + compression: Some("zstd(22)".into()), + dictionary_enabled: Some(!defaults.dictionary_enabled.unwrap_or(false)), + dictionary_page_size_limit: 42, + statistics_enabled: Some("chunk".into()), + max_statistics_size: Some(42), + max_row_group_size: 42, + created_by: "wordy".into(), + column_index_truncate_length: Some(42), + data_page_row_count_limit: 42, + encoding: Some("BYTE_STREAM_SPLIT".into()), + bloom_filter_on_write: !defaults.bloom_filter_on_write, + bloom_filter_fpp: Some(0.42), + bloom_filter_ndv: Some(42), + + // not in WriterProperties, but itemizing here to not skip newly added props + enable_page_index: defaults.enable_page_index, + pruning: defaults.pruning, + skip_metadata: defaults.skip_metadata, + metadata_size_hint: defaults.metadata_size_hint, + pushdown_filters: defaults.pushdown_filters, + reorder_filters: defaults.reorder_filters, + allow_single_file_parallelism: defaults.allow_single_file_parallelism, + maximum_parallel_row_group_writers: defaults + .maximum_parallel_row_group_writers, + maximum_buffered_record_batches_per_stream: defaults + .maximum_buffered_record_batches_per_stream, + bloom_filter_on_read: defaults.bloom_filter_on_read, + } + } + + fn extract_column_options( + props: &WriterProperties, + col: ColumnPath, + ) -> ParquetColumnOptions { + let bloom_filter_default_props = props.bloom_filter_properties(&col); + + ParquetColumnOptions { + bloom_filter_enabled: Some(bloom_filter_default_props.is_some()), + encoding: props.encoding(&col).map(|s| s.to_string()), + dictionary_enabled: Some(props.dictionary_enabled(&col)), + compression: match props.compression(&col) { + Compression::ZSTD(lvl) => { + Some(format!("zstd({})", lvl.compression_level())) + } + _ => None, + }, + statistics_enabled: Some( + match props.statistics_enabled(&col) { + EnabledStatistics::None => "none", + EnabledStatistics::Chunk => "chunk", + EnabledStatistics::Page => "page", + } + .into(), + ), + bloom_filter_fpp: bloom_filter_default_props.map(|p| p.fpp), + bloom_filter_ndv: bloom_filter_default_props.map(|p| p.ndv), + max_statistics_size: Some(props.max_statistics_size(&col)), + } + } + + /// For testing only, take a single write's props and convert back into the session config. + /// (use identity to confirm correct.) 
+ fn session_config_from_writer_props(props: &WriterProperties) -> TableParquetOptions { + let default_col = ColumnPath::from("col doesn't have specific config"); + let default_col_props = extract_column_options(props, default_col); + + let configured_col = ColumnPath::from(COL_NAME); + let configured_col_props = extract_column_options(props, configured_col); + + let key_value_metadata = props + .key_value_metadata() + .map(|pairs| { + HashMap::from_iter( + pairs + .iter() + .cloned() + .map(|KeyValue { key, value }| (key, value)), + ) + }) + .unwrap_or_default(); + + let global_options_defaults = ParquetOptions::default(); + + let column_specific_options = if configured_col_props.eq(&default_col_props) { + HashMap::default() + } else { + HashMap::from([(COL_NAME.into(), configured_col_props)]) + }; + + TableParquetOptions { + global: ParquetOptions { + // global options + data_pagesize_limit: props.dictionary_page_size_limit(), + write_batch_size: props.write_batch_size(), + writer_version: format!("{}.0", props.writer_version().as_num()), + dictionary_page_size_limit: props.dictionary_page_size_limit(), + max_row_group_size: props.max_row_group_size(), + created_by: props.created_by().to_string(), + column_index_truncate_length: props.column_index_truncate_length(), + data_page_row_count_limit: props.data_page_row_count_limit(), + + // global options which set the default column props + encoding: default_col_props.encoding, + compression: default_col_props.compression, + dictionary_enabled: default_col_props.dictionary_enabled, + statistics_enabled: default_col_props.statistics_enabled, + max_statistics_size: default_col_props.max_statistics_size, + bloom_filter_on_write: default_col_props + .bloom_filter_enabled + .unwrap_or_default(), + bloom_filter_fpp: default_col_props.bloom_filter_fpp, + bloom_filter_ndv: default_col_props.bloom_filter_ndv, + + // not in WriterProperties + enable_page_index: global_options_defaults.enable_page_index, + pruning: global_options_defaults.pruning, + skip_metadata: global_options_defaults.skip_metadata, + metadata_size_hint: global_options_defaults.metadata_size_hint, + pushdown_filters: global_options_defaults.pushdown_filters, + reorder_filters: global_options_defaults.reorder_filters, + allow_single_file_parallelism: global_options_defaults + .allow_single_file_parallelism, + maximum_parallel_row_group_writers: global_options_defaults + .maximum_parallel_row_group_writers, + maximum_buffered_record_batches_per_stream: global_options_defaults + .maximum_buffered_record_batches_per_stream, + bloom_filter_on_read: global_options_defaults.bloom_filter_on_read, + }, + column_specific_options, + key_value_metadata, + } + } + + #[test] + fn table_parquet_opts_to_writer_props() { + // ParquetOptions, all props set to non-default + let parquet_options = parquet_options_with_non_defaults(); + + // TableParquetOptions, using ParquetOptions for global settings + let key = "foo".to_string(); + let value = Some("bar".into()); + let table_parquet_opts = TableParquetOptions { + global: parquet_options.clone(), + column_specific_options: [( + COL_NAME.into(), + column_options_with_non_defaults(&parquet_options), + )] + .into(), + key_value_metadata: [(key.clone(), value.clone())].into(), + }; + + let writer_props = WriterPropertiesBuilder::try_from(&table_parquet_opts) + .unwrap() + .build(); + assert_eq!( + table_parquet_opts, + session_config_from_writer_props(&writer_props), + "the writer_props should have the same configuration as the session's 
TableParquetOptions", + ); + } + + /// Ensure that the configuration defaults for writing parquet files are + /// consistent with the options in arrow-rs + #[test] + fn test_defaults_match() { + // ensure the global settings are the same + let default_table_writer_opts = TableParquetOptions::default(); + let default_parquet_opts = ParquetOptions::default(); + assert_eq!( + default_table_writer_opts.global, + default_parquet_opts, + "should have matching defaults for TableParquetOptions.global and ParquetOptions", + ); + + // WriterProperties::default, a.k.a. using extern parquet's defaults + let default_writer_props = WriterProperties::new(); + + // WriterProperties::try_from(TableParquetOptions::default), a.k.a. using datafusion's defaults + let from_datafusion_defaults = + WriterPropertiesBuilder::try_from(&default_table_writer_opts) + .unwrap() + .build(); + + // Expected: how the defaults should not match + assert_ne!( + default_writer_props.created_by(), + from_datafusion_defaults.created_by(), + "should have different created_by sources", + ); + assert!( + default_writer_props.created_by().starts_with("parquet-rs version"), + "should indicate that writer_props defaults came from the extern parquet crate", + ); + assert!( + default_table_writer_opts + .global + .created_by + .starts_with("datafusion version"), + "should indicate that table_parquet_opts defaults came from datafusion", + ); + + // Expected: the remaining should match + let same_created_by = default_table_writer_opts.global.created_by.clone(); + let mut from_extern_parquet = + session_config_from_writer_props(&default_writer_props); + from_extern_parquet.global.created_by = same_created_by; + // TODO: the remaining defaults do not match! + // refer to https://github.com/apache/datafusion/issues/11367 + assert_ne!( + default_table_writer_opts, + from_extern_parquet, + "the default writer_props should have the same configuration as the session's default TableParquetOptions", + ); + + // Below here itemizes how the defaults **should** match, but do not. + + // TODO: compression defaults do not match + // refer to https://github.com/apache/datafusion/issues/11367 + assert_eq!( + default_writer_props.compression(&"default".into()), + Compression::UNCOMPRESSED, + "extern parquet's default is None" + ); + assert!( + matches!( + from_datafusion_defaults.compression(&"default".into()), + Compression::ZSTD(_) + ), + "datafusion's default is zstd" + ); + + // datafusion's `None` for Option => becomes parquet's EnabledStatistics::Page + // TODO: should this be changed? + // refer to https://github.com/apache/datafusion/issues/11367 + assert_eq!( + default_writer_props.statistics_enabled(&"default".into()), + EnabledStatistics::Page, + "extern parquet's default is page" + ); + assert_eq!( + default_table_writer_opts.global.statistics_enabled, None, + "datafusion's has no default" + ); + assert_eq!( + from_datafusion_defaults.statistics_enabled(&"default".into()), + EnabledStatistics::Page, + "should see the extern parquet's default over-riding datafusion's None", + ); + + // Confirm all other settings are equal. + // First resolve the known discrepancies, (set as the same). + // TODO: once we fix the above mis-matches, we should be able to remove this. 
+ let mut from_extern_parquet = + session_config_from_writer_props(&default_writer_props); + from_extern_parquet.global.compression = Some("zstd(3)".into()); + from_extern_parquet.global.statistics_enabled = None; + + // Expected: the remaining should match + let same_created_by = default_table_writer_opts.global.created_by.clone(); // we expect these to be different + from_extern_parquet.global.created_by = same_created_by; // we expect these to be different + assert_eq!( + default_table_writer_opts, + from_extern_parquet, + "the default writer_props should have the same configuration as the session's default TableParquetOptions", + ); + } + + #[test] + fn test_bloom_filter_defaults() { + // the TableParquetOptions::default, with only the bloom filter turned on + let mut default_table_writer_opts = TableParquetOptions::default(); + default_table_writer_opts.global.bloom_filter_on_write = true; + + // the WriterProperties::default, with only the bloom filter turned on + let default_writer_props = WriterProperties::new(); + let from_datafusion_defaults = + WriterPropertiesBuilder::try_from(&default_table_writer_opts) + .unwrap() + .set_bloom_filter_enabled(true) + .build(); + + // TODO: should have same behavior in either. + // refer to https://github.com/apache/datafusion/issues/11367 + assert_ne!( + default_writer_props.bloom_filter_properties(&"default".into()), + from_datafusion_defaults.bloom_filter_properties(&"default".into()), + "parquet and datafusion props, will not have the same bloom filter props", + ); + assert_eq!( + default_writer_props.bloom_filter_properties(&"default".into()), + None, + "extern parquet's default remains None" + ); + assert_eq!( + from_datafusion_defaults.bloom_filter_properties(&"default".into()), + Some(&BloomFilterProperties::default()), + "datafusion's has BloomFilterProperties::default", + ); + } + + #[test] + fn test_bloom_filter_set_fpp_only() { + // the TableParquetOptions::default, with only fpp set + let mut default_table_writer_opts = TableParquetOptions::default(); + default_table_writer_opts.global.bloom_filter_on_write = true; + default_table_writer_opts.global.bloom_filter_fpp = Some(0.42); + + // the WriterProperties::default, with only fpp set + let default_writer_props = WriterProperties::new(); + let from_datafusion_defaults = + WriterPropertiesBuilder::try_from(&default_table_writer_opts) + .unwrap() + .set_bloom_filter_enabled(true) + .set_bloom_filter_fpp(0.42) + .build(); + + // TODO: should have same behavior in either. 
+ // refer to https://github.com/apache/datafusion/issues/11367 + assert_ne!( + default_writer_props.bloom_filter_properties(&"default".into()), + from_datafusion_defaults.bloom_filter_properties(&"default".into()), + "parquet and datafusion props, will not have the same bloom filter props", + ); + assert_eq!( + default_writer_props.bloom_filter_properties(&"default".into()), + None, + "extern parquet's default remains None" + ); + assert_eq!( + from_datafusion_defaults.bloom_filter_properties(&"default".into()), + Some(&BloomFilterProperties { + fpp: 0.42, + ndv: DEFAULT_BLOOM_FILTER_NDV + }), + "datafusion's has BloomFilterProperties", + ); + } + + #[test] + fn test_bloom_filter_set_ndv_only() { + // the TableParquetOptions::default, with only ndv set + let mut default_table_writer_opts = TableParquetOptions::default(); + default_table_writer_opts.global.bloom_filter_on_write = true; + default_table_writer_opts.global.bloom_filter_ndv = Some(42); + + // the WriterProperties::default, with only ndv set + let default_writer_props = WriterProperties::new(); + let from_datafusion_defaults = + WriterPropertiesBuilder::try_from(&default_table_writer_opts) + .unwrap() + .set_bloom_filter_enabled(true) + .set_bloom_filter_ndv(42) + .build(); + + // TODO: should have same behavior in either. + // refer to https://github.com/apache/datafusion/issues/11367 + assert_ne!( + default_writer_props.bloom_filter_properties(&"default".into()), + from_datafusion_defaults.bloom_filter_properties(&"default".into()), + "parquet and datafusion props, will not have the same bloom filter props", + ); + assert_eq!( + default_writer_props.bloom_filter_properties(&"default".into()), + None, + "extern parquet's default remains None" + ); + assert_eq!( + from_datafusion_defaults.bloom_filter_properties(&"default".into()), + Some(&BloomFilterProperties { + fpp: DEFAULT_BLOOM_FILTER_FPP, + ndv: 42 + }), + "datafusion's has BloomFilterProperties", + ); + } +} diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index d1c3747b52b4..452f1862b274 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -433,7 +433,7 @@ impl FunctionalDependencies { } /// This function ensures that functional dependencies involving uniquely - /// occuring determinant keys cover their entire table in terms of + /// occurring determinant keys cover their entire table in terms of /// dependent columns. 
pub fn extend_target_indices(&mut self, n_out: usize) { self.deps.iter_mut().for_each( diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index c8adae34f645..010221b0485f 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -29,8 +29,8 @@ use arrow_buffer::IntervalMonthDayNano; use crate::cast::{ as_boolean_array, as_fixed_size_list_array, as_generic_binary_array, - as_large_list_array, as_list_array, as_primitive_array, as_string_array, - as_struct_array, + as_large_list_array, as_list_array, as_map_array, as_primitive_array, + as_string_array, as_struct_array, }; use crate::error::{Result, _internal_err}; @@ -236,6 +236,40 @@ fn hash_struct_array( Ok(()) } +fn hash_map_array( + array: &MapArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + let nulls = array.nulls(); + let offsets = array.offsets(); + + // Create hashes for each entry in each row + let mut values_hashes = vec![0u64; array.entries().len()]; + create_hashes(array.entries().columns(), random_state, &mut values_hashes)?; + + // Combine the hashes for entries on each row with each other and previous hash for that row + if let Some(nulls) = nulls { + for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { + if nulls.is_valid(i) { + let hash = &mut hashes_buffer[i]; + for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + *hash = combine_hashes(*hash, *values_hash); + } + } + } + } else { + for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { + let hash = &mut hashes_buffer[i]; + for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + *hash = combine_hashes(*hash, *values_hash); + } + } + } + + Ok(()) +} + fn hash_list_array( array: &GenericListArray, random_state: &RandomState, @@ -400,6 +434,10 @@ pub fn create_hashes<'a>( let array = as_large_list_array(array)?; hash_list_array(array, random_state, hashes_buffer)?; } + DataType::Map(_, _) => { + let array = as_map_array(array)?; + hash_map_array(array, random_state, hashes_buffer)?; + } DataType::FixedSizeList(_,_) => { let array = as_fixed_size_list_array(array)?; hash_fixed_list_array(array, random_state, hashes_buffer)?; @@ -572,6 +610,7 @@ mod tests { Some(vec![Some(3), None, Some(5)]), None, Some(vec![Some(0), Some(1), Some(2)]), + Some(vec![]), ]; let list_array = Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; @@ -581,6 +620,7 @@ mod tests { assert_eq!(hashes[0], hashes[5]); assert_eq!(hashes[1], hashes[4]); assert_eq!(hashes[2], hashes[3]); + assert_eq!(hashes[1], hashes[6]); // null vs empty list } #[test] @@ -692,6 +732,64 @@ mod tests { assert_eq!(hashes[0], hashes[1]); } + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_map_arrays() { + let mut builder = + MapBuilder::new(None, StringBuilder::new(), Int32Builder::new()); + // Row 0 + builder.keys().append_value("key1"); + builder.keys().append_value("key2"); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true).unwrap(); + // Row 1 + builder.keys().append_value("key1"); + builder.keys().append_value("key2"); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true).unwrap(); + // Row 2 + builder.keys().append_value("key1"); + builder.keys().append_value("key2"); + builder.values().append_value(1); + 
builder.values().append_value(3); + builder.append(true).unwrap(); + // Row 3 + builder.keys().append_value("key1"); + builder.keys().append_value("key3"); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true).unwrap(); + // Row 4 + builder.keys().append_value("key1"); + builder.values().append_value(1); + builder.append(true).unwrap(); + // Row 5 + builder.keys().append_value("key1"); + builder.values().append_null(); + builder.append(true).unwrap(); + // Row 6 + builder.append(true).unwrap(); + // Row 7 + builder.keys().append_value("key1"); + builder.values().append_value(1); + builder.append(false).unwrap(); + + let array = Arc::new(builder.finish()) as ArrayRef; + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; array.len()]; + create_hashes(&[array], &random_state, &mut hashes).unwrap(); + assert_eq!(hashes[0], hashes[1]); // same value + assert_ne!(hashes[0], hashes[2]); // different value + assert_ne!(hashes[0], hashes[3]); // different key + assert_ne!(hashes[0], hashes[4]); // missing an entry + assert_ne!(hashes[4], hashes[5]); // filled vs null value + assert_eq!(hashes[6], hashes[7]); // empty vs null map + } + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 38f70e4c1466..92ed897e7185 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1063,7 +1063,6 @@ impl ScalarValue { /// Create an one value in the given type. pub fn new_one(datatype: &DataType) -> Result { - assert!(datatype.is_primitive()); Ok(match datatype { DataType::Int8 => ScalarValue::Int8(Some(1)), DataType::Int16 => ScalarValue::Int16(Some(1)), @@ -1086,7 +1085,6 @@ impl ScalarValue { /// Create a negative one value in the given type. 
pub fn new_negative_one(datatype: &DataType) -> Result { - assert!(datatype.is_primitive()); Ok(match datatype { DataType::Int8 | DataType::UInt8 => ScalarValue::Int8(Some(-1)), DataType::Int16 | DataType::UInt16 => ScalarValue::Int16(Some(-1)), @@ -1104,7 +1102,6 @@ impl ScalarValue { } pub fn new_ten(datatype: &DataType) -> Result { - assert!(datatype.is_primitive()); Ok(match datatype { DataType::Int8 => ScalarValue::Int8(Some(10)), DataType::Int16 => ScalarValue::Int16(Some(10)), @@ -1773,6 +1770,7 @@ impl ScalarValue { } DataType::List(_) | DataType::LargeList(_) + | DataType::Map(_, _) | DataType::Struct(_) | DataType::Union(_, _) => { let arrays = scalars.map(|s| s.to_array()).collect::>>()?; @@ -1841,7 +1839,6 @@ impl ScalarValue { | DataType::Time32(TimeUnit::Nanosecond) | DataType::Time64(TimeUnit::Second) | DataType::Time64(TimeUnit::Millisecond) - | DataType::Map(_, _) | DataType::RunEndEncoded(_, _) | DataType::ListView(_) | DataType::LargeListView(_) => { diff --git a/datafusion/common/src/table_reference.rs b/datafusion/common/src/table_reference.rs index b6ccaa74d5fc..67f3da4f48de 100644 --- a/datafusion/common/src/table_reference.rs +++ b/datafusion/common/src/table_reference.rs @@ -62,7 +62,7 @@ impl std::fmt::Display for ResolvedTableReference { /// assert_eq!(table_reference, TableReference::bare("mytable")); /// /// // Get a table reference to 'MyTable' (note the capitalization) using double quotes -/// // (programatically it is better to use `TableReference::bare` for this) +/// // (programmatically it is better to use `TableReference::bare` for this) /// let table_reference = TableReference::from(r#""MyTable""#); /// assert_eq!(table_reference, TableReference::bare("MyTable")); /// diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 532ca8fde9e7..4301396b231f 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -106,6 +106,7 @@ datafusion-functions-array = { workspace = true, optional = true } datafusion-optimizer = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } +datafusion-physical-optimizer = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-sql = { workspace = true } flate2 = { version = "1.0.24", optional = true } @@ -216,3 +217,8 @@ name = "topk_aggregate" [[bench]] harness = false name = "parquet_statistic" + +[[bench]] +harness = false +name = "map_query_sql" +required-features = ["array_expressions"] diff --git a/datafusion/core/benches/map_query_sql.rs b/datafusion/core/benches/map_query_sql.rs new file mode 100644 index 000000000000..b6ac8b6b647a --- /dev/null +++ b/datafusion/core/benches/map_query_sql.rs @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow_array::{ArrayRef, Int32Array, RecordBatch}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use parking_lot::Mutex; +use rand::prelude::ThreadRng; +use rand::Rng; +use tokio::runtime::Runtime; + +use datafusion::prelude::SessionContext; +use datafusion_common::ScalarValue; +use datafusion_expr::Expr; +use datafusion_functions_array::map::map; + +mod data_utils; + +fn build_keys(rng: &mut ThreadRng) -> Vec { + let mut keys = vec![]; + for _ in 0..1000 { + keys.push(rng.gen_range(0..9999).to_string()); + } + keys +} + +fn build_values(rng: &mut ThreadRng) -> Vec { + let mut values = vec![]; + for _ in 0..1000 { + values.push(rng.gen_range(0..9999)); + } + values +} + +fn t_batch(num: i32) -> RecordBatch { + let value: Vec = (0..num).collect(); + let c1: ArrayRef = Arc::new(Int32Array::from(value)); + RecordBatch::try_from_iter(vec![("c1", c1)]).unwrap() +} + +fn create_context(num: i32) -> datafusion_common::Result>> { + let ctx = SessionContext::new(); + ctx.register_batch("t", t_batch(num))?; + Ok(Arc::new(Mutex::new(ctx))) +} + +fn criterion_benchmark(c: &mut Criterion) { + let ctx = create_context(1).unwrap(); + let rt = Runtime::new().unwrap(); + let df = rt.block_on(ctx.lock().table("t")).unwrap(); + + let mut rng = rand::thread_rng(); + let keys = build_keys(&mut rng); + let values = build_values(&mut rng); + let mut key_buffer = Vec::new(); + let mut value_buffer = Vec::new(); + + for i in 0..1000 { + key_buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone())))); + value_buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])))); + } + c.bench_function("map_1000_1", |b| { + b.iter(|| { + black_box( + rt.block_on( + df.clone() + .select(vec![map(key_buffer.clone(), value_buffer.clone())]) + .unwrap() + .collect(), + ) + .unwrap(), + ); + }); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/core/example.parquet b/datafusion/core/example.parquet new file mode 100644 index 000000000000..94de10394b33 Binary files /dev/null and b/datafusion/core/example.parquet differ diff --git a/datafusion/core/src/catalog/information_schema.rs b/datafusion/core/src/catalog/information_schema.rs index c953de6d16d3..a79f62e742be 100644 --- a/datafusion/core/src/catalog/information_schema.rs +++ b/datafusion/core/src/catalog/information_schema.rs @@ -314,7 +314,7 @@ impl InformationSchemaTablesBuilder { table_name: impl AsRef, table_type: TableType, ) { - // Note: append_value is actually infallable. + // Note: append_value is actually infallible. self.catalog_names.append_value(catalog_name.as_ref()); self.schema_names.append_value(schema_name.as_ref()); self.table_names.append_value(table_name.as_ref()); @@ -405,7 +405,7 @@ impl InformationSchemaViewBuilder { table_name: impl AsRef, definition: Option>, ) { - // Note: append_value is actually infallable. + // Note: append_value is actually infallible. 
self.catalog_names.append_value(catalog_name.as_ref()); self.schema_names.append_value(schema_name.as_ref()); self.table_names.append_value(table_name.as_ref()); diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index c55b7c752765..fb28b5c1ab47 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1696,10 +1696,10 @@ mod tests { use datafusion_common::{Constraint, Constraints, ScalarValue}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ - array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction, - ScalarFunctionImplementation, Volatility, WindowFrame, WindowFunctionDefinition, + cast, create_udf, expr, lit, BuiltInWindowFunction, ScalarFunctionImplementation, + Volatility, WindowFrame, WindowFunctionDefinition, }; - use datafusion_functions_aggregate::expr_fn::count_distinct; + use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 9a3aa2454e27..8b6a8800119d 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -66,7 +66,7 @@ const INITIAL_BUFFER_BYTES: usize = 1048576; /// If the buffered Arrow data exceeds this size, it is flushed to object store const BUFFER_FLUSH_BYTES: usize = 1024000; -#[derive(Default)] +#[derive(Default, Debug)] /// Factory struct used to create [ArrowFormat] pub struct ArrowFormatFactory; @@ -89,6 +89,10 @@ impl FileFormatFactory for ArrowFormatFactory { fn default(&self) -> Arc { Arc::new(ArrowFormat) } + + fn as_any(&self) -> &dyn Any { + self + } } impl GetExt for ArrowFormatFactory { @@ -353,7 +357,7 @@ async fn infer_schema_from_file_stream( // Expected format: // - 6 bytes // - 2 bytes - // - 4 bytes, not present below v0.15.0 + // - 4 bytes, not present below v0.15.0 // - 4 bytes // // @@ -365,7 +369,7 @@ async fn infer_schema_from_file_stream( // Files should start with these magic bytes if bytes[0..6] != ARROW_MAGIC { return Err(ArrowError::ParseError( - "Arrow file does not contian correct header".to_string(), + "Arrow file does not contain correct header".to_string(), ))?; } diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index f4f9adcba7ed..5190bdbe153a 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::collections::HashMap; +use std::fmt; use std::sync::Arc; use arrow::datatypes::Schema; @@ -64,6 +65,16 @@ impl FileFormatFactory for AvroFormatFactory { fn default(&self) -> Arc { Arc::new(AvroFormat) } + + fn as_any(&self) -> &dyn Any { + self + } +} + +impl fmt::Debug for AvroFormatFactory { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AvroFormatFactory").finish() + } } impl GetExt for AvroFormatFactory { diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index baeaf51fb56d..e1b6daac092d 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -58,7 +58,8 @@ use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore} #[derive(Default)] /// Factory 
struct used to create [CsvFormatFactory] pub struct CsvFormatFactory { - options: Option, + /// the options for csv file read + pub options: Option, } impl CsvFormatFactory { @@ -75,6 +76,14 @@ impl CsvFormatFactory { } } +impl fmt::Debug for CsvFormatFactory { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CsvFormatFactory") + .field("options", &self.options) + .finish() + } +} + impl FileFormatFactory for CsvFormatFactory { fn create( &self, @@ -103,6 +112,10 @@ impl FileFormatFactory for CsvFormatFactory { fn default(&self) -> Arc { Arc::new(CsvFormat::default()) } + + fn as_any(&self) -> &dyn Any { + self + } } impl GetExt for CsvFormatFactory { @@ -233,6 +246,18 @@ impl CsvFormat { self } + /// Specifies whether newlines in (quoted) values are supported. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + /// + /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. + pub fn with_newlines_in_values(mut self, newlines_in_values: bool) -> Self { + self.options.newlines_in_values = Some(newlines_in_values); + self + } + /// Set a `FileCompressionType` of CSV /// - defaults to `FileCompressionType::UNCOMPRESSED` pub fn with_file_compression_type( @@ -330,6 +355,9 @@ impl FileFormat for CsvFormat { self.options.quote, self.options.escape, self.options.comment, + self.options + .newlines_in_values + .unwrap_or(state.config_options().catalog.newlines_in_values), self.options.compression.into(), ); Ok(Arc::new(exec)) @@ -645,7 +673,7 @@ mod tests { let session_ctx = SessionContext::new_with_config(config); let state = session_ctx.state(); let task_ctx = state.task_ctx(); - // skip column 9 that overflows the automaticly discovered column type of i64 (u64 would work) + // skip column 9 that overflows the automatically discovered column type of i64 (u64 would work) let projection = Some(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12]); let exec = get_exec(&state, "aggregate_test_100.csv", projection, None, true).await?; @@ -1052,6 +1080,41 @@ mod tests { Ok(()) } + #[rstest(n_partitions, case(1), case(2), case(3), case(4))] + #[tokio::test] + async fn test_csv_parallel_newlines_in_values(n_partitions: usize) -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(n_partitions); + let csv_options = CsvReadOptions::default() + .has_header(true) + .newlines_in_values(true); + let ctx = SessionContext::new_with_config(config); + let testdata = arrow_test_data(); + ctx.register_csv( + "aggr", + &format!("{testdata}/csv/aggregate_test_100.csv"), + csv_options, + ) + .await?; + + let query = "select sum(c3) from aggr;"; + let query_result = ctx.sql(query).await?.collect().await?; + let actual_partitions = count_query_csv_partitions(&ctx, query).await?; + + #[rustfmt::skip] + let expected = ["+--------------+", + "| sum(aggr.c3) |", + "+--------------+", + "| 781 |", + "+--------------+"]; + assert_batches_eq!(expected, &query_result); + assert_eq!(1, actual_partitions); // csv won't be scanned in parallel when newlines_in_values is set + + Ok(()) + } + /// Read a single empty csv file in parallel /// /// empty_0_byte.csv: @@ -1251,11 +1314,8 @@ mod tests { "+-----------------------+", "| 50 |", "+-----------------------+"]; - let file_size = if cfg!(target_os = 
"windows") { - 30 // new line on Win is '\r\n' - } else { - 20 - }; + + let file_size = std::fs::metadata("tests/data/one_col.csv")?.len() as usize; // A 20-Byte file at most get partitioned into 20 chunks let expected_partitions = if n_partitions <= file_size { n_partitions diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 007b084f504d..9de9c3d7d871 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -102,6 +102,10 @@ impl FileFormatFactory for JsonFormatFactory { fn default(&self) -> Arc { Arc::new(JsonFormat::default()) } + + fn as_any(&self) -> &dyn Any { + self + } } impl GetExt for JsonFormatFactory { @@ -111,6 +115,14 @@ impl GetExt for JsonFormatFactory { } } +impl fmt::Debug for JsonFormatFactory { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("JsonFormatFactory") + .field("options", &self.options) + .finish() + } +} + /// New line delimited JSON `FileFormat` implementation. #[derive(Debug, Default)] pub struct JsonFormat { diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index 1aa93a106aff..500f20af474f 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -49,11 +49,11 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use async_trait::async_trait; use file_compression_type::FileCompressionType; use object_store::{ObjectMeta, ObjectStore}; - +use std::fmt::Debug; /// Factory for creating [`FileFormat`] instances based on session and command level options /// /// Users can provide their own `FileFormatFactory` to support arbitrary file formats -pub trait FileFormatFactory: Sync + Send + GetExt { +pub trait FileFormatFactory: Sync + Send + GetExt + Debug { /// Initialize a [FileFormat] and configure based on session and command level options fn create( &self, @@ -63,6 +63,10 @@ pub trait FileFormatFactory: Sync + Send + GetExt { /// Initialize a [FileFormat] with all options set to default values fn default(&self) -> Arc; + + /// Returns the table source as [`Any`] so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; } /// This trait abstracts all the file format specific implementations @@ -138,6 +142,7 @@ pub trait FileFormat: Send + Sync + fmt::Debug { /// The former trait is a superset of the latter trait, which includes execution time /// relevant methods. [FileType] is only used in logical planning and only implements /// the subset of methods required during logical planning. 
+#[derive(Debug)] pub struct DefaultFileType { file_format_factory: Arc, } @@ -149,6 +154,11 @@ impl DefaultFileType { file_format_factory, } } + + /// get a reference to the inner [FileFormatFactory] struct + pub fn as_format_factory(&self) -> &Arc { + &self.file_format_factory + } } impl FileType for DefaultFileType { @@ -159,7 +169,7 @@ impl FileType for DefaultFileType { impl Display for DefaultFileType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.file_format_factory.default().fmt(f) + write!(f, "{:?}", self.file_format_factory) } } diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index c6d143ed6749..552977baba17 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -63,6 +63,14 @@ pub struct CsvReadOptions<'a> { pub escape: Option, /// If enabled, lines beginning with this byte are ignored. pub comment: Option, + /// Specifies whether newlines in (quoted) values are supported. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + /// + /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. + pub newlines_in_values: bool, /// An optional schema representing the CSV files. If None, CSV reader will try to infer it /// based on data in file. pub schema: Option<&'a Schema>, @@ -95,6 +103,7 @@ impl<'a> CsvReadOptions<'a> { delimiter: b',', quote: b'"', escape: None, + newlines_in_values: false, file_extension: DEFAULT_CSV_EXTENSION, table_partition_cols: vec![], file_compression_type: FileCompressionType::UNCOMPRESSED, @@ -133,6 +142,18 @@ impl<'a> CsvReadOptions<'a> { self } + /// Specifies whether newlines in (quoted) values are supported. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + /// + /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. 
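A usage sketch of the per-read option described above; the file path is a placeholder and the tokio setup is assumed:

```rust
use datafusion::error::Result;
use datafusion::prelude::{CsvReadOptions, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Accept line breaks inside quoted fields; this keeps the scan on a
    // single partition, trading parallelism for correct parsing.
    let options = CsvReadOptions::default()
        .has_header(true)
        .newlines_in_values(true);
    let df = ctx.read_csv("example_with_newlines.csv", options).await?;
    df.show().await?;
    Ok(())
}
```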
+ pub fn newlines_in_values(mut self, newlines_in_values: bool) -> Self { + self.newlines_in_values = newlines_in_values; + self + } + /// Specify the file extension for CSV file selection pub fn file_extension(mut self, file_extension: &'a str) -> Self { self.file_extension = file_extension; @@ -490,6 +511,7 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { .with_delimiter(self.delimiter) .with_quote(self.quote) .with_escape(self.escape) + .with_newlines_in_values(self.newlines_in_values) .with_schema_infer_max_rec(self.schema_infer_max_records) .with_file_compression_type(self.file_compression_type.to_owned()); diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 6271d8af3786..3250b59fa1d1 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -140,6 +140,10 @@ impl FileFormatFactory for ParquetFormatFactory { fn default(&self) -> Arc { Arc::new(ParquetFormat::default()) } + + fn as_any(&self) -> &dyn Any { + self + } } impl GetExt for ParquetFormatFactory { @@ -149,6 +153,13 @@ impl GetExt for ParquetFormatFactory { } } +impl fmt::Debug for ParquetFormatFactory { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ParquetFormatFactory") + .field("ParquetFormatFactory", &self.options) + .finish() + } +} /// The Apache Parquet `FileFormat` implementation #[derive(Debug, Default)] pub struct ParquetFormat { @@ -854,14 +865,14 @@ fn spawn_column_parallel_row_group_writer( let mut col_array_channels = Vec::with_capacity(num_columns); for writer in col_writers.into_iter() { // Buffer size of this channel limits the number of arrays queued up for column level serialization - let (send_array, recieve_array) = + let (send_array, receive_array) = mpsc::channel::(max_buffer_size); col_array_channels.push(send_array); let reservation = MemoryConsumer::new("ParquetSink(ArrowColumnWriter)").register(pool); let task = SpawnedTask::spawn(column_serializer_task( - recieve_array, + receive_array, writer, reservation, )); @@ -936,7 +947,7 @@ fn spawn_rg_join_and_finalize_task( /// row group is reached, the parallel tasks are joined on another separate task /// and sent to a concatenation task. This task immediately continues to work /// on the next row group in parallel. So, parquet serialization is parallelized -/// accross both columns and row_groups, with a theoretical max number of parallel tasks +/// across both columns and row_groups, with a theoretical max number of parallel tasks /// given by n_columns * num_row_groups. fn spawn_parquet_parallel_serialization_task( mut data: Receiver, @@ -1560,7 +1571,7 @@ mod tests { // . batch1 written into first file and includes: // - column c1 that has 3 rows with one null. Stats min and max of string column is missing for this test even the column has values // . batch2 written into second file and includes: - // - column c2 that has 3 rows with one null. Stats min and max of int are avaialble and 1 and 2 respectively + // - column c2 that has 3 rows with one null. 
Stats min and max of int are available and 1 and 2 respectively let store = Arc::new(LocalFileSystem::new()) as _; let (files, _file_names) = store_parquet(vec![batch1, batch2], false).await?; @@ -2112,7 +2123,7 @@ mod tests { let path_parts = path.parts().collect::>(); assert_eq!(path_parts.len(), 1, "should not have path prefix"); - assert_eq!(num_rows, 2, "file metdata to have 2 rows"); + assert_eq!(num_rows, 2, "file metadata to have 2 rows"); assert!( schema.iter().any(|col_schema| col_schema.name == "a"), "output file metadata should contain col a" @@ -2208,7 +2219,7 @@ mod tests { ); expected_partitions.remove(prefix); - assert_eq!(num_rows, 1, "file metdata to have 1 row"); + assert_eq!(num_rows, 1, "file metadata to have 1 row"); assert!( !schema.iter().any(|col_schema| col_schema.name == "a"), "output file metadata will not contain partitioned col a" diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/core/src/datasource/file_format/write/demux.rs index e29c877442cf..a58c77e31313 100644 --- a/datafusion/core/src/datasource/file_format/write/demux.rs +++ b/datafusion/core/src/datasource/file_format/write/demux.rs @@ -54,7 +54,7 @@ type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>; /// which should be contained within the same output file. The outer channel /// is used to send a dynamic number of inner channels, representing a dynamic /// number of total output files. The caller is also responsible to monitor -/// the demux task for errors and abort accordingly. The single_file_ouput parameter +/// the demux task for errors and abort accordingly. The single_file_output parameter /// overrides all other settings to force only a single file to be written. /// partition_by parameter will additionally split the input based on the unique /// values of a specific column ``` diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs index 8bd0dae9f5a4..1d32063ee9f3 100644 --- a/datafusion/core/src/datasource/file_format/write/orchestration.rs +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -42,6 +42,37 @@ use tokio::task::JoinSet; type WriterType = Box; type SerializerType = Arc; +/// Result of calling [`serialize_rb_stream_to_object_store`] +pub(crate) enum SerializedRecordBatchResult { + Success { + /// the writer + writer: WriterType, + + /// the number of rows successfully written + row_count: usize, + }, + Failure { + /// As explained in [`serialize_rb_stream_to_object_store`]: + /// - If an IO error occured that involved the ObjectStore writer, then the writer will not be returned to the caller + /// - Otherwise, the writer is returned to the caller + writer: Option, + + /// the actual error that occured + err: DataFusionError, + }, +} + +impl SerializedRecordBatchResult { + /// Create the success variant + pub fn success(writer: WriterType, row_count: usize) -> Self { + Self::Success { writer, row_count } + } + + pub fn failure(writer: Option, err: DataFusionError) -> Self { + Self::Failure { writer, err } + } +} + /// Serializes a single data stream in parallel and writes to an ObjectStore concurrently. /// Data order is preserved. 
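The tuple-shaped `Result<(WriterType, u64), (Option<WriterType>, DataFusionError)>` return is replaced by the named-variant enum above. A standalone sketch of the caller-side pattern, using simplified stand-in types (`Writer`, a `String` error) rather than the crate-private ones:

```rust
struct Writer; // stand-in for the boxed object store writer

enum SerializedResult {
    Success { writer: Writer, row_count: usize },
    Failure { writer: Option<Writer>, err: String },
}

fn handle(result: SerializedResult) -> Result<usize, String> {
    match result {
        SerializedResult::Success { writer: _writer, row_count } => {
            // keep the writer for reuse/finalization, report rows written
            Ok(row_count)
        }
        SerializedResult::Failure { writer, err } => {
            // Some(writer) can still be aborted so partial output is cleaned up;
            // None means the object store writer itself failed
            if writer.is_some() { /* abort it here */ }
            Err(err)
        }
    }
}

fn main() {
    let ok = handle(SerializedResult::Success { writer: Writer, row_count: 3 });
    assert_eq!(ok, Ok(3));
}
```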
/// @@ -55,7 +86,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store( mut data_rx: Receiver, serializer: Arc, mut writer: WriterType, -) -> std::result::Result<(WriterType, u64), (Option, DataFusionError)> { +) -> SerializedRecordBatchResult { let (tx, mut rx) = mpsc::channel::>>(100); let serialize_task = SpawnedTask::spawn(async move { @@ -86,43 +117,43 @@ pub(crate) async fn serialize_rb_stream_to_object_store( match writer.write_all(&bytes).await { Ok(_) => (), Err(e) => { - return Err(( + return SerializedRecordBatchResult::failure( None, DataFusionError::Execution(format!( "Error writing to object store: {e}" )), - )) + ) } }; row_count += cnt; } Ok(Err(e)) => { // Return the writer along with the error - return Err((Some(writer), e)); + return SerializedRecordBatchResult::failure(Some(writer), e); } Err(e) => { // Handle task panic or cancellation - return Err(( + return SerializedRecordBatchResult::failure( Some(writer), DataFusionError::Execution(format!( "Serialization task panicked or was cancelled: {e}" )), - )); + ); } } } match serialize_task.join().await { Ok(Ok(_)) => (), - Ok(Err(e)) => return Err((Some(writer), e)), + Ok(Err(e)) => return SerializedRecordBatchResult::failure(Some(writer), e), Err(_) => { - return Err(( + return SerializedRecordBatchResult::failure( Some(writer), internal_datafusion_err!("Unknown error writing to object store"), - )) + ) } } - Ok((writer, row_count as u64)) + SerializedRecordBatchResult::success(writer, row_count) } type FileWriteBundle = (Receiver, SerializerType, WriterType); @@ -141,7 +172,7 @@ pub(crate) async fn stateless_serialize_and_write_files( // tracks the specific error triggering abort let mut triggering_error = None; // tracks if any errors were encountered in the process of aborting writers. - // if true, we may not have a guarentee that all written data was cleaned up. + // if true, we may not have a guarantee that all written data was cleaned up. let mut any_abort_errors = false; let mut join_set = JoinSet::new(); while let Some((data_rx, serializer, writer)) = rx.recv().await { @@ -153,14 +184,17 @@ pub(crate) async fn stateless_serialize_and_write_files( while let Some(result) = join_set.join_next().await { match result { Ok(res) => match res { - Ok((writer, cnt)) => { + SerializedRecordBatchResult::Success { + writer, + row_count: cnt, + } => { finished_writers.push(writer); row_count += cnt; } - Err((writer, e)) => { + SerializedRecordBatchResult::Failure { writer, err } => { finished_writers.extend(writer); any_errors = true; - triggering_error = Some(e); + triggering_error = Some(err); } }, Err(e) => { @@ -188,12 +222,12 @@ pub(crate) async fn stateless_serialize_and_write_files( true => return internal_err!("Error encountered during writing to ObjectStore and failed to abort all writers. Partial result may have been written."), false => match triggering_error { Some(e) => return Err(e), - None => return internal_err!("Unknown Error encountered during writing to ObjectStore. All writers succesfully aborted.") + None => return internal_err!("Unknown Error encountered during writing to ObjectStore. All writers successfully aborted.") } } } - tx.send(row_count).map_err(|_| { + tx.send(row_count as u64).map_err(|_| { internal_datafusion_err!( "Error encountered while sending row count back to file sink!" 
) @@ -268,7 +302,7 @@ pub(crate) async fn stateless_multipart_put( r2?; let total_count = rx_row_cnt.await.map_err(|_| { - internal_datafusion_err!("Did not receieve row count from write coordinater") + internal_datafusion_err!("Did not receive row count from write coordinator") })?; Ok(total_count) diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index c1ce4cc5b6c5..bfc33ce0bd73 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -759,7 +759,7 @@ mod tests { .otherwise(lit(false)) .expect("valid case expr")) )); - // static expression not relvant in this context but we + // static expression not relevant in this context but we // test it as an edge case anyway in case we want to generalize // this helper function assert!(expr_applicable_for_cols(&[], &lit(true))); diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index ea4d396a14cb..4d0a7738b039 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -695,8 +695,8 @@ impl ListingTable { } /// Specify the SQL definition for this table, if any - pub fn with_definition(mut self, defintion: Option) -> Self { - self.definition = defintion; + pub fn with_definition(mut self, definition: Option) -> Self { + self.definition = definition; self } @@ -1038,8 +1038,8 @@ mod tests { use crate::datasource::file_format::avro::AvroFormat; use crate::datasource::file_format::csv::CsvFormat; use crate::datasource::file_format::json::JsonFormat; - use crate::datasource::file_format::parquet::ParquetFormat; #[cfg(feature = "parquet")] + use crate::datasource::file_format::parquet::ParquetFormat; use crate::datasource::{provider_as_source, MemTable}; use crate::execution::options::ArrowReadOptions; use crate::physical_plan::collect; diff --git a/datafusion/core/src/datasource/listing/url.rs b/datafusion/core/src/datasource/listing/url.rs index 73fffd8abaed..7566df628ed7 100644 --- a/datafusion/core/src/datasource/listing/url.rs +++ b/datafusion/core/src/datasource/listing/url.rs @@ -53,7 +53,7 @@ impl ListingTableUrl { /// subdirectories. /// /// Similarly `s3://BUCKET/blob.csv` refers to `blob.csv` in the S3 bucket `BUCKET`, - /// wherease `s3://BUCKET/foo/` refers to all objects with the prefix `foo/` in the + /// whereas `s3://BUCKET/foo/` refers to all objects with the prefix `foo/` in the /// S3 bucket `BUCKET` /// /// # URL Encoding diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index aab42285a0b2..5c4928209559 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -644,7 +644,7 @@ mod tests { Ok(partitions) } - /// Returns the value of results. For example, returns 6 given the follwing + /// Returns the value of results. 
For example, returns 6 given the following /// /// ```text /// +-------+, diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 327fbd976e87..fb0e23c6c164 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -59,6 +59,7 @@ pub struct CsvExec { quote: u8, escape: Option, comment: Option, + newlines_in_values: bool, /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Compression type of the file associated with CsvExec @@ -68,6 +69,7 @@ pub struct CsvExec { impl CsvExec { /// Create a new CSV reader execution plan provided base and specific configurations + #[allow(clippy::too_many_arguments)] pub fn new( base_config: FileScanConfig, has_header: bool, @@ -75,6 +77,7 @@ impl CsvExec { quote: u8, escape: Option, comment: Option, + newlines_in_values: bool, file_compression_type: FileCompressionType, ) -> Self { let (projected_schema, projected_statistics, projected_output_ordering) = @@ -91,6 +94,7 @@ impl CsvExec { delimiter, quote, escape, + newlines_in_values, metrics: ExecutionPlanMetricsSet::new(), file_compression_type, cache, @@ -126,6 +130,17 @@ impl CsvExec { self.escape } + /// Specifies whether newlines in (quoted) values are supported. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + /// + /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. + pub fn newlines_in_values(&self) -> bool { + self.newlines_in_values + } + fn output_partitioning_helper(file_scan_config: &FileScanConfig) -> Partitioning { Partitioning::UnknownPartitioning(file_scan_config.file_groups.len()) } @@ -196,15 +211,15 @@ impl ExecutionPlan for CsvExec { /// Redistribute files across partitions according to their size /// See comments on [`FileGroupPartitioner`] for more detail. /// - /// Return `None` if can't get repartitioned(empty/compressed file). + /// Return `None` if can't get repartitioned (empty, compressed file, or `newlines_in_values` set). fn repartitioned( &self, target_partitions: usize, config: &ConfigOptions, ) -> Result>> { let repartition_file_min_size = config.optimizer.repartition_file_min_size; - // Parallel execution on compressed CSV file is not supported yet. - if self.file_compression_type.is_compressed() { + // Parallel execution on compressed CSV files or files that must support newlines in values is not supported yet. 
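The doc comments above also refer to a session-wide default, `datafusion.catalog.newlines_in_values`, alongside the per-read flag. A minimal sketch of setting it for a whole session; the option key is taken from those doc comments and `SessionConfig::set_bool` is used as the generic config setter:

```rust
use datafusion::prelude::{SessionConfig, SessionContext};

fn main() {
    // Every CSV scan in this session will tolerate newlines in quoted values
    // (and therefore will not be split into parallel file scan partitions).
    let config = SessionConfig::new()
        .set_bool("datafusion.catalog.newlines_in_values", true);
    let _ctx = SessionContext::new_with_config(config);
}
```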
+ if self.file_compression_type.is_compressed() || self.newlines_in_values { return Ok(None); } @@ -589,6 +604,7 @@ mod tests { b'"', None, None, + false, file_compression_type.to_owned(), ); assert_eq!(13, csv.base_config.file_schema.fields().len()); @@ -658,6 +674,7 @@ mod tests { b'"', None, None, + false, file_compression_type.to_owned(), ); assert_eq!(13, csv.base_config.file_schema.fields().len()); @@ -727,6 +744,7 @@ mod tests { b'"', None, None, + false, file_compression_type.to_owned(), ); assert_eq!(13, csv.base_config.file_schema.fields().len()); @@ -793,6 +811,7 @@ mod tests { b'"', None, None, + false, file_compression_type.to_owned(), ); assert_eq!(14, csv.base_config.file_schema.fields().len()); @@ -858,6 +877,7 @@ mod tests { b'"', None, None, + false, file_compression_type.to_owned(), ); assert_eq!(13, csv.base_config.file_schema.fields().len()); @@ -953,6 +973,7 @@ mod tests { b'"', None, None, + false, file_compression_type.to_owned(), ); diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index f5d3c7a6410d..17850ea7585a 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -56,7 +56,7 @@ pub fn wrap_partition_type_in_dict(val_type: DataType) -> DataType { } /// Convert a [`ScalarValue`] of partition columns to a type, as -/// decribed in the documentation of [`wrap_partition_type_in_dict`], +/// described in the documentation of [`wrap_partition_type_in_dict`], /// which can wrap the types. pub fn wrap_partition_value_in_dict(val: ScalarValue) -> ScalarValue { ScalarValue::Dictionary(Box::new(DataType::UInt16), Box::new(val)) @@ -682,7 +682,7 @@ mod tests { vec![table_partition_col.clone()], ); - // verify the proj_schema inlcudes the last column and exactly the same the field it is defined + // verify the proj_schema includes the last column and exactly the same the field it is defined let (proj_schema, _proj_statistics, _) = conf.project(); assert_eq!(proj_schema.fields().len(), file_schema.fields().len() + 1); assert_eq!( diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index ed0fc5f0169e..1eea4eab8ba2 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use crate::datasource::listing::PartitionedFile; use crate::datasource::physical_plan::file_stream::FileStream; use crate::datasource::physical_plan::{ - parquet::page_filter::PagePruningPredicate, DisplayAs, FileGroupPartitioner, + parquet::page_filter::PagePruningAccessPlanFilter, DisplayAs, FileGroupPartitioner, FileScanConfig, }; use crate::{ @@ -39,13 +39,11 @@ use crate::{ }, }; -use arrow::datatypes::{DataType, SchemaRef}; +use arrow::datatypes::SchemaRef; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, PhysicalExpr}; use itertools::Itertools; use log::debug; -use parquet::basic::{ConvertedType, LogicalType}; -use parquet::schema::types::ColumnDescriptor; mod access_plan; mod metrics; @@ -225,7 +223,7 @@ pub struct ParquetExec { /// Optional predicate for pruning row groups (derived from `predicate`) pruning_predicate: Option>, /// Optional predicate for pruning pages (derived from `predicate`) - page_pruning_predicate: Option>, + page_pruning_predicate: Option>, /// Optional hint for the size of 
the parquet metadata metadata_size_hint: Option, /// Optional user defined parquet file reader factory @@ -381,19 +379,12 @@ impl ParquetExecBuilder { }) .filter(|p| !p.always_true()); - let page_pruning_predicate = predicate.as_ref().and_then(|predicate_expr| { - match PagePruningPredicate::try_new(predicate_expr, file_schema.clone()) { - Ok(pruning_predicate) => Some(Arc::new(pruning_predicate)), - Err(e) => { - debug!( - "Could not create page pruning predicate for '{:?}': {}", - pruning_predicate, e - ); - predicate_creation_errors.add(1); - None - } - } - }); + let page_pruning_predicate = predicate + .as_ref() + .map(|predicate_expr| { + PagePruningAccessPlanFilter::new(predicate_expr, file_schema.clone()) + }) + .map(Arc::new); let (projected_schema, projected_statistics, projected_output_ordering) = base_config.project(); @@ -739,7 +730,7 @@ impl ExecutionPlan for ParquetExec { fn should_enable_page_index( enable_page_index: bool, - page_pruning_predicate: &Option>, + page_pruning_predicate: &Option>, ) -> bool { enable_page_index && page_pruning_predicate.is_some() @@ -749,26 +740,6 @@ fn should_enable_page_index( .unwrap_or(false) } -// Convert parquet column schema to arrow data type, and just consider the -// decimal data type. -pub(crate) fn parquet_to_arrow_decimal_type( - parquet_column: &ColumnDescriptor, -) -> Option { - let type_ptr = parquet_column.self_type_ptr(); - match type_ptr.get_basic_info().logical_type() { - Some(LogicalType::Decimal { scale, precision }) => { - Some(DataType::Decimal128(precision as u8, scale as i8)) - } - _ => match type_ptr.get_basic_info().converted_type() { - ConvertedType::DECIMAL => Some(DataType::Decimal128( - type_ptr.get_precision() as u8, - type_ptr.get_scale() as i8, - )), - _ => None, - }, - } -} - #[cfg(test)] mod tests { // See also `parquet_exec` integration test @@ -798,7 +769,7 @@ mod tests { }; use arrow::datatypes::{Field, Schema, SchemaBuilder}; use arrow::record_batch::RecordBatch; - use arrow_schema::Fields; + use arrow_schema::{DataType, Fields}; use datafusion_common::{assert_contains, ScalarValue}; use datafusion_expr::{col, lit, when, Expr}; use datafusion_physical_expr::planner::logical2physical; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs index c97b0282626a..ffe879eb8de0 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs @@ -17,7 +17,7 @@ //! 
[`ParquetOpener`] for opening Parquet files -use crate::datasource::physical_plan::parquet::page_filter::PagePruningPredicate; +use crate::datasource::physical_plan::parquet::page_filter::PagePruningAccessPlanFilter; use crate::datasource::physical_plan::parquet::row_group_filter::RowGroupAccessPlanFilter; use crate::datasource::physical_plan::parquet::{ row_filter, should_enable_page_index, ParquetAccessPlan, @@ -46,7 +46,7 @@ pub(super) struct ParquetOpener { pub limit: Option, pub predicate: Option>, pub pruning_predicate: Option>, - pub page_pruning_predicate: Option>, + pub page_pruning_predicate: Option>, pub table_schema: SchemaRef, pub metadata_size_hint: Option, pub metrics: ExecutionPlanMetricsSet, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs index 7429ca593820..d658608ab4f1 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs @@ -17,40 +17,33 @@ //! Contains code to filter entire pages -use arrow::array::{ - BooleanArray, Decimal128Array, Float32Array, Float64Array, Int32Array, Int64Array, - StringArray, -}; -use arrow::datatypes::DataType; +use crate::datasource::physical_plan::parquet::ParquetAccessPlan; +use crate::datasource::physical_plan::parquet::StatisticsConverter; +use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; +use arrow::array::BooleanArray; use arrow::{array::ArrayRef, datatypes::SchemaRef}; use arrow_schema::Schema; -use datafusion_common::{Result, ScalarValue}; -use datafusion_physical_expr::expressions::Column; +use datafusion_common::ScalarValue; use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; use log::{debug, trace}; -use parquet::schema::types::{ColumnDescriptor, SchemaDescriptor}; +use parquet::file::metadata::{ParquetColumnIndex, ParquetOffsetIndex}; +use parquet::format::PageLocation; +use parquet::schema::types::SchemaDescriptor; use parquet::{ arrow::arrow_reader::{RowSelection, RowSelector}, - file::{ - metadata::{ParquetMetaData, RowGroupMetaData}, - page_index::index::Index, - }, - format::PageLocation, + file::metadata::{ParquetMetaData, RowGroupMetaData}, }; use std::collections::HashSet; use std::sync::Arc; -use crate::datasource::physical_plan::parquet::parquet_to_arrow_decimal_type; -use crate::datasource::physical_plan::parquet::statistics::{ - from_bytes_to_i128, parquet_column, -}; -use crate::datasource::physical_plan::parquet::ParquetAccessPlan; -use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; - use super::metrics::ParquetFileMetrics; -/// A [`PagePruningPredicate`] provides the ability to construct a [`RowSelection`] -/// based on parquet page level statistics, if any +/// Filters a [`ParquetAccessPlan`] based on the [Parquet PageIndex], if present +/// +/// It does so by evaluating statistics from the [`ParquetColumnIndex`] and +/// [`ParquetOffsetIndex`] and converting them to [`RowSelection`]. +/// +/// [Parquet PageIndex]: https://github.com/apache/parquet-format/blob/master/PageIndex.md /// /// For example, given a row group with two column (chunks) for `A` /// and `B` with the following with page level statistics: @@ -103,30 +96,52 @@ use super::metrics::ParquetFileMetrics; /// /// So we can entirely skip rows 0->199 and 250->299 as we know they /// can not contain rows that match the predicate. 
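The selection described in the example above maps directly onto the `parquet` crate types this filter emits; a small sketch:

```rust
use parquet::arrow::arrow_reader::{RowSelection, RowSelector};

fn main() {
    // Skip rows 0..200, read rows 200..250, skip rows 250..300
    let selection = RowSelection::from(vec![
        RowSelector::skip(200),
        RowSelector::select(50),
        RowSelector::skip(50),
    ]);
    println!("{selection:?}");
}
```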
+/// +/// # Implementation notes +/// +/// Single column predicates are evaluated using the PageIndex information +/// for that column to determine which row ranges can be skipped based. +/// +/// The resulting [`RowSelection`]'s are combined into a final +/// row selection that is added to the [`ParquetAccessPlan`]. #[derive(Debug)] -pub struct PagePruningPredicate { +pub struct PagePruningAccessPlanFilter { + /// single column predicates (e.g. (`col = 5`) extracted from the overall + /// predicate. Must all be true for a row to be included in the result. predicates: Vec, } -impl PagePruningPredicate { - /// Create a new [`PagePruningPredicate`] - // TODO: this is infallaible -- it can not return an error - pub fn try_new(expr: &Arc, schema: SchemaRef) -> Result { +impl PagePruningAccessPlanFilter { + /// Create a new [`PagePruningAccessPlanFilter`] from a physical + /// expression. + pub fn new(expr: &Arc, schema: SchemaRef) -> Self { + // extract any single column predicates let predicates = split_conjunction(expr) .into_iter() .filter_map(|predicate| { - match PruningPredicate::try_new(predicate.clone(), schema.clone()) { - Ok(p) - if (!p.always_true()) - && (p.required_columns().n_columns() < 2) => - { - Some(Ok(p)) - } - _ => None, + let pp = + match PruningPredicate::try_new(predicate.clone(), schema.clone()) { + Ok(pp) => pp, + Err(e) => { + debug!("Ignoring error creating page pruning predicate: {e}"); + return None; + } + }; + + if pp.always_true() { + debug!("Ignoring always true page pruning predicate: {predicate}"); + return None; + } + + if pp.required_columns().single_column().is_none() { + debug!("Ignoring multi-column page pruning predicate: {predicate}"); + return None; } + + Some(pp) }) - .collect::>>()?; - Ok(Self { predicates }) + .collect::>(); + Self { predicates } } /// Returns an updated [`ParquetAccessPlan`] by applying predicates to the @@ -136,7 +151,7 @@ impl PagePruningPredicate { mut access_plan: ParquetAccessPlan, arrow_schema: &Schema, parquet_schema: &SchemaDescriptor, - file_metadata: &ParquetMetaData, + parquet_metadata: &ParquetMetaData, file_metrics: &ParquetFileMetrics, ) -> ParquetAccessPlan { // scoped timer updates on drop @@ -146,18 +161,18 @@ impl PagePruningPredicate { } let page_index_predicates = &self.predicates; - let groups = file_metadata.row_groups(); + let groups = parquet_metadata.row_groups(); if groups.is_empty() { return access_plan; } - let (Some(file_offset_indexes), Some(file_page_indexes)) = - (file_metadata.offset_index(), file_metadata.column_index()) - else { - trace!( - "skip page pruning due to lack of indexes. Have offset: {}, column index: {}", - file_metadata.offset_index().is_some(), file_metadata.column_index().is_some() + if parquet_metadata.offset_index().is_none() + || parquet_metadata.column_index().is_none() + { + debug!( + "Can not prune pages due to lack of indexes. 
Have offset: {}, column index: {}", + parquet_metadata.offset_index().is_some(), parquet_metadata.column_index().is_some() ); return access_plan; }; @@ -165,33 +180,39 @@ impl PagePruningPredicate { // track the total number of rows that should be skipped let mut total_skip = 0; + // for each row group specified in the access plan let row_group_indexes = access_plan.row_group_indexes(); - for r in row_group_indexes { + for row_group_index in row_group_indexes { // The selection for this particular row group let mut overall_selection = None; for predicate in page_index_predicates { - // find column index in the parquet schema - let col_idx = find_column_index(predicate, arrow_schema, parquet_schema); - let row_group_metadata = &groups[r]; - - let (Some(rg_page_indexes), Some(rg_offset_indexes), Some(col_idx)) = ( - file_page_indexes.get(r), - file_offset_indexes.get(r), - col_idx, - ) else { - trace!( - "Did not have enough metadata to prune with page indexes, \ - falling back to all rows", - ); - continue; + let column = predicate + .required_columns() + .single_column() + .expect("Page pruning requires single column predicates"); + + let converter = StatisticsConverter::try_new( + column.name(), + arrow_schema, + parquet_schema, + ); + + let converter = match converter { + Ok(converter) => converter, + Err(e) => { + debug!( + "Could not create statistics converter for column {}: {e}", + column.name() + ); + continue; + } }; let selection = prune_pages_in_one_row_group( - row_group_metadata, + row_group_index, predicate, - rg_offset_indexes.get(col_idx), - rg_page_indexes.get(col_idx), - groups[r].column(col_idx).column_descr(), + converter, + parquet_metadata, file_metrics, ); @@ -224,15 +245,15 @@ impl PagePruningPredicate { let rows_skipped = rows_skipped(&overall_selection); trace!("Overall selection from predicate skipped {rows_skipped}: {overall_selection:?}"); total_skip += rows_skipped; - access_plan.scan_selection(r, overall_selection) + access_plan.scan_selection(row_group_index, overall_selection) } else { // Selection skips all rows, so skip the entire row group - let rows_skipped = groups[r].num_rows() as usize; - access_plan.skip(r); + let rows_skipped = groups[row_group_index].num_rows() as usize; + access_plan.skip(row_group_index); total_skip += rows_skipped; trace!( "Overall selection from predicate is empty, \ - skipping all {rows_skipped} rows in row group {r}" + skipping all {rows_skipped} rows in row group {row_group_index}" ); } } @@ -242,7 +263,7 @@ impl PagePruningPredicate { access_plan } - /// Returns the number of filters in the [`PagePruningPredicate`] + /// Returns the number of filters in the [`PagePruningAccessPlanFilter`] pub fn filter_number(&self) -> usize { self.predicates.len() } @@ -266,97 +287,53 @@ fn update_selection( } } -/// Returns the column index in the row parquet schema for the single -/// column of a single column pruning predicate. -/// -/// For example, give the predicate `y > 5` +/// Returns a [`RowSelection`] for the rows in this row group to scan. /// -/// And columns in the RowGroupMetadata like `['x', 'y', 'z']` will -/// return 1. +/// This Row Selection is formed from the page index and the predicate skips row +/// ranges that can be ruled out based on the predicate. 
/// -/// Returns `None` if the column is not found, or if there are no -/// required columns, which is the case for predicate like `abs(i) = -/// 1` which are rewritten to `lit(true)` -/// -/// Panics: -/// -/// If the predicate contains more than one column reference (assumes -/// that `extract_page_index_push_down_predicates` only returns -/// predicate with one col) -fn find_column_index( - predicate: &PruningPredicate, - arrow_schema: &Schema, - parquet_schema: &SchemaDescriptor, -) -> Option { - let mut found_required_column: Option<&Column> = None; - - for required_column_details in predicate.required_columns().iter() { - let column = &required_column_details.0; - if let Some(found_required_column) = found_required_column.as_ref() { - // make sure it is the same name we have seen previously - assert_eq!( - column.name(), - found_required_column.name(), - "Unexpected multi column predicate" - ); - } else { - found_required_column = Some(column); - } - } - - let Some(column) = found_required_column.as_ref() else { - trace!("No column references in pruning predicate"); - return None; - }; - - parquet_column(parquet_schema, arrow_schema, column.name()).map(|x| x.0) -} - -/// Returns a `RowSelection` for the pages in this RowGroup if any -/// rows can be pruned based on the page index +/// Returns `None` if there is an error evaluating the predicate or the required +/// page information is not present. fn prune_pages_in_one_row_group( - group: &RowGroupMetaData, - predicate: &PruningPredicate, - col_offset_indexes: Option<&Vec>, - col_page_indexes: Option<&Index>, - col_desc: &ColumnDescriptor, + row_group_index: usize, + pruning_predicate: &PruningPredicate, + converter: StatisticsConverter<'_>, + parquet_metadata: &ParquetMetaData, metrics: &ParquetFileMetrics, ) -> Option { - let num_rows = group.num_rows() as usize; - let (Some(col_offset_indexes), Some(col_page_indexes)) = - (col_offset_indexes, col_page_indexes) - else { - return None; - }; - - let target_type = parquet_to_arrow_decimal_type(col_desc); - let pruning_stats = PagesPruningStatistics { - col_page_indexes, - col_offset_indexes, - target_type: &target_type, - num_rows_in_row_group: group.num_rows(), - }; + let pruning_stats = + PagesPruningStatistics::try_new(row_group_index, converter, parquet_metadata)?; - let values = match predicate.prune(&pruning_stats) { + // Each element in values is a boolean indicating whether the page may have + // values that match the predicate (true) or could not possibly have values + // that match the predicate (false). + let values = match pruning_predicate.prune(&pruning_stats) { Ok(values) => values, Err(e) => { - // stats filter array could not be built - // return a result which will not filter out any pages debug!("Error evaluating page index predicate values {e}"); metrics.predicate_evaluation_errors.add(1); return None; } }; + // Convert the information of which pages to skip into a RowSelection + // that describes the ranges of rows to skip. 
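The loop that follows merges adjacent pages with the same keep/skip verdict into single selectors. A standalone toy version of that idea (no null handling or metrics, `parquet` crate types assumed), just to make the shape of the output concrete:

```rust
use parquet::arrow::arrow_reader::{RowSelection, RowSelector};

/// `keep[i]` says whether page i may contain matching rows and
/// `page_row_counts[i]` is that page's row count; adjacent pages with the
/// same verdict collapse into one selector.
fn to_selection(keep: &[bool], page_row_counts: &[usize]) -> RowSelection {
    let mut selectors = Vec::new();
    let mut run_rows = page_row_counts[0];
    let mut selected = keep[0];
    for i in 1..keep.len() {
        if keep[i] == selected {
            run_rows += page_row_counts[i];
        } else {
            selectors.push(if selected {
                RowSelector::select(run_rows)
            } else {
                RowSelector::skip(run_rows)
            });
            run_rows = page_row_counts[i];
            selected = keep[i];
        }
    }
    selectors.push(if selected {
        RowSelector::select(run_rows)
    } else {
        RowSelector::skip(run_rows)
    });
    RowSelection::from(selectors)
}

fn main() {
    // Pages of 100, 100 and 50 rows; only the first two may contain matches.
    let selection = to_selection(&[true, true, false], &[100, 100, 50]);
    println!("{selection:?}"); // equivalent to select(200), skip(50)
}
```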
+ let Some(page_row_counts) = pruning_stats.page_row_counts() else { + debug!( + "Can not determine page row counts for row group {row_group_index}, skipping" + ); + metrics.predicate_evaluation_errors.add(1); + return None; + }; + let mut vec = Vec::with_capacity(values.len()); - let row_vec = create_row_count_in_each_page(col_offset_indexes, num_rows); - assert_eq!(row_vec.len(), values.len()); - let mut sum_row = *row_vec.first().unwrap(); + assert_eq!(page_row_counts.len(), values.len()); + let mut sum_row = *page_row_counts.first().unwrap(); let mut selected = *values.first().unwrap(); trace!("Pruned to {:?} using {:?}", values, pruning_stats); for (i, &f) in values.iter().enumerate().skip(1) { if f == selected { - sum_row += *row_vec.get(i).unwrap(); + sum_row += *page_row_counts.get(i).unwrap(); } else { let selector = if selected { RowSelector::select(sum_row) @@ -364,7 +341,7 @@ fn prune_pages_in_one_row_group( RowSelector::skip(sum_row) }; vec.push(selector); - sum_row = *row_vec.get(i).unwrap(); + sum_row = *page_row_counts.get(i).unwrap(); selected = f; } } @@ -378,206 +355,143 @@ fn prune_pages_in_one_row_group( Some(RowSelection::from(vec)) } -fn create_row_count_in_each_page( - location: &[PageLocation], - num_rows: usize, -) -> Vec { - let mut vec = Vec::with_capacity(location.len()); - location.windows(2).for_each(|x| { - let start = x[0].first_row_index as usize; - let end = x[1].first_row_index as usize; - vec.push(end - start); - }); - vec.push(num_rows - location.last().unwrap().first_row_index as usize); - vec -} - -/// Wraps one col page_index in one rowGroup statistics in a way -/// that implements [`PruningStatistics`] +/// Implement [`PruningStatistics`] for one column's PageIndex (column_index + offset_index) #[derive(Debug)] struct PagesPruningStatistics<'a> { - col_page_indexes: &'a Index, - col_offset_indexes: &'a Vec, - // target_type means the logical type in schema: like 'DECIMAL' is the logical type, but the - // real physical type in parquet file may be `INT32, INT64, FIXED_LEN_BYTE_ARRAY` - target_type: &'a Option, - num_rows_in_row_group: i64, + row_group_index: usize, + row_group_metadatas: &'a [RowGroupMetaData], + converter: StatisticsConverter<'a>, + column_index: &'a ParquetColumnIndex, + offset_index: &'a ParquetOffsetIndex, + page_offsets: &'a Vec, } -// Extract the min or max value calling `func` from page idex -macro_rules! 
get_min_max_values_for_page_index { - ($self:expr, $func:ident) => {{ - match $self.col_page_indexes { - Index::NONE => None, - Index::INT32(index) => { - match $self.target_type { - // int32 to decimal with the precision and scale - Some(DataType::Decimal128(precision, scale)) => { - let vec = &index.indexes; - let vec: Vec> = vec - .iter() - .map(|x| x.$func().and_then(|x| Some(*x as i128))) - .collect(); - Decimal128Array::from(vec) - .with_precision_and_scale(*precision, *scale) - .ok() - .map(|arr| Arc::new(arr) as ArrayRef) - } - _ => { - let vec = &index.indexes; - Some(Arc::new(Int32Array::from_iter( - vec.iter().map(|x| x.$func().cloned()), - ))) - } - } - } - Index::INT64(index) => { - match $self.target_type { - // int64 to decimal with the precision and scale - Some(DataType::Decimal128(precision, scale)) => { - let vec = &index.indexes; - let vec: Vec> = vec - .iter() - .map(|x| x.$func().and_then(|x| Some(*x as i128))) - .collect(); - Decimal128Array::from(vec) - .with_precision_and_scale(*precision, *scale) - .ok() - .map(|arr| Arc::new(arr) as ArrayRef) - } - _ => { - let vec = &index.indexes; - Some(Arc::new(Int64Array::from_iter( - vec.iter().map(|x| x.$func().cloned()), - ))) - } - } - } - Index::FLOAT(index) => { - let vec = &index.indexes; - Some(Arc::new(Float32Array::from_iter( - vec.iter().map(|x| x.$func().cloned()), - ))) - } - Index::DOUBLE(index) => { - let vec = &index.indexes; - Some(Arc::new(Float64Array::from_iter( - vec.iter().map(|x| x.$func().cloned()), - ))) - } - Index::BOOLEAN(index) => { - let vec = &index.indexes; - Some(Arc::new(BooleanArray::from_iter( - vec.iter().map(|x| x.$func().cloned()), - ))) - } - Index::BYTE_ARRAY(index) => match $self.target_type { - Some(DataType::Decimal128(precision, scale)) => { - let vec = &index.indexes; - Decimal128Array::from( - vec.iter() - .map(|x| { - x.$func() - .and_then(|x| Some(from_bytes_to_i128(x.as_ref()))) - }) - .collect::>>(), - ) - .with_precision_and_scale(*precision, *scale) - .ok() - .map(|arr| Arc::new(arr) as ArrayRef) - } - _ => { - let vec = &index.indexes; - let array: StringArray = vec - .iter() - .map(|x| x.$func()) - .map(|x| x.and_then(|x| std::str::from_utf8(x.as_ref()).ok())) - .collect(); - Some(Arc::new(array)) - } - }, - Index::INT96(_) => { - //Todo support these type - None - } - Index::FIXED_LEN_BYTE_ARRAY(index) => match $self.target_type { - Some(DataType::Decimal128(precision, scale)) => { - let vec = &index.indexes; - Decimal128Array::from( - vec.iter() - .map(|x| { - x.$func() - .and_then(|x| Some(from_bytes_to_i128(x.as_ref()))) - }) - .collect::>>(), - ) - .with_precision_and_scale(*precision, *scale) - .ok() - .map(|arr| Arc::new(arr) as ArrayRef) - } - _ => None, - }, - } - }}; +impl<'a> PagesPruningStatistics<'a> { + /// Creates a new [`PagesPruningStatistics`] for a column in a row group, if + /// possible. + /// + /// Returns None if the `parquet_metadata` does not have sufficient + /// information to create the statistics. 
+ fn try_new( + row_group_index: usize, + converter: StatisticsConverter<'a>, + parquet_metadata: &'a ParquetMetaData, + ) -> Option { + let Some(parquet_column_index) = converter.parquet_index() else { + trace!( + "Column {:?} not in parquet file, skipping", + converter.arrow_field() + ); + return None; + }; + + let column_index = parquet_metadata.column_index()?; + let offset_index = parquet_metadata.offset_index()?; + let row_group_metadatas = parquet_metadata.row_groups(); + + let Some(row_group_page_offsets) = offset_index.get(row_group_index) else { + trace!("No page offsets for row group {row_group_index}, skipping"); + return None; + }; + let Some(page_offsets) = row_group_page_offsets.get(parquet_column_index) else { + trace!( + "No page offsets for column {:?} in row group {row_group_index}, skipping", + converter.arrow_field() + ); + return None; + }; + + Some(Self { + row_group_index, + row_group_metadatas, + converter, + column_index, + offset_index, + page_offsets, + }) + } + + /// Returns the row counts in each data page, if possible. + fn page_row_counts(&self) -> Option> { + let row_group_metadata = self + .row_group_metadatas + .get(self.row_group_index) + // fail fast/panic if row_group_index is out of bounds + .unwrap(); + + let num_rows_in_row_group = row_group_metadata.num_rows() as usize; + + let page_offsets = self.page_offsets; + let mut vec = Vec::with_capacity(page_offsets.len()); + page_offsets.windows(2).for_each(|x| { + let start = x[0].first_row_index as usize; + let end = x[1].first_row_index as usize; + vec.push(end - start); + }); + vec.push(num_rows_in_row_group - page_offsets.last()?.first_row_index as usize); + Some(vec) + } } impl<'a> PruningStatistics for PagesPruningStatistics<'a> { fn min_values(&self, _column: &datafusion_common::Column) -> Option { - get_min_max_values_for_page_index!(self, min) + match self.converter.data_page_mins( + self.column_index, + self.offset_index, + [&self.row_group_index], + ) { + Ok(min_values) => Some(min_values), + Err(e) => { + debug!("Error evaluating data page min values {e}"); + None + } + } } fn max_values(&self, _column: &datafusion_common::Column) -> Option { - get_min_max_values_for_page_index!(self, max) + match self.converter.data_page_maxes( + self.column_index, + self.offset_index, + [&self.row_group_index], + ) { + Ok(max_values) => Some(max_values), + Err(e) => { + debug!("Error evaluating data page max values {e}"); + None + } + } } fn num_containers(&self) -> usize { - self.col_offset_indexes.len() + self.page_offsets.len() } fn null_counts(&self, _column: &datafusion_common::Column) -> Option { - match self.col_page_indexes { - Index::NONE => None, - Index::BOOLEAN(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::INT32(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::INT64(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::FLOAT(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::DOUBLE(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::INT96(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::BYTE_ARRAY(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::FIXED_LEN_BYTE_ARRAY(index) =>
Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), + match self.converter.data_page_null_counts( + self.column_index, + self.offset_index, + [&self.row_group_index], + ) { + Ok(null_counts) => Some(Arc::new(null_counts)), + Err(e) => { + debug!("Error evaluating data page null counts {e}"); + None + } } } fn row_counts(&self, _column: &datafusion_common::Column) -> Option { - // see https://github.com/apache/arrow-rs/blob/91f0b1771308609ca27db0fb1d2d49571b3980d8/parquet/src/file/metadata.rs#L979-L982 - - let row_count_per_page = self.col_offset_indexes.windows(2).map(|location| { - Some(location[1].first_row_index - location[0].first_row_index) - }); - - // append the last page row count - let row_count_per_page = row_count_per_page.chain(std::iter::once(Some( - self.num_rows_in_row_group - - self.col_offset_indexes.last().unwrap().first_row_index, - ))); - - Some(Arc::new(Int64Array::from_iter(row_count_per_page))) + match self.converter.data_page_row_counts( + self.offset_index, + self.row_group_metadatas, + [&self.row_group_index], + ) { + Ok(row_counts) => row_counts.map(|a| Arc::new(a) as ArrayRef), + Err(e) => { + debug!("Error evaluating data page row counts {e}"); + None + } + } } fn contained( diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs index 9bc79805746f..170beb15ead2 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs @@ -1123,7 +1123,7 @@ mod tests { } #[tokio::test] - async fn test_row_group_bloom_filter_pruning_predicate_mutiple_expr() { + async fn test_row_group_bloom_filter_pruning_predicate_multiple_expr() { BloomFilterTest::new_data_index_bloom_encoding_stats() .with_expect_all_pruned() // generate pruning predicate `(String = "Hello_Not_exists" OR String = "Hello_Not_exists2")` diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index 59369aba57a9..3d250718f736 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -1136,6 +1136,16 @@ pub struct StatisticsConverter<'a> { } impl<'a> StatisticsConverter<'a> { + /// Return the index of the column in the parquet file, if any + pub fn parquet_index(&self) -> Option { + self.parquet_index + } + + /// Return the arrow field of the column in the arrow schema + pub fn arrow_field(&self) -> &'a Field { + self.arrow_field + } + /// Returns a [`UInt64Array`] with row counts for each row group /// /// # Return Value @@ -2532,10 +2542,10 @@ mod test { fn timestamp_seconds_array( input: impl IntoIterator>, - timzezone: Option<&str>, + timezone: Option<&str>, ) -> ArrayRef { let array: TimestampSecondArray = input.into_iter().collect(); - match timzezone { + match timezone { Some(tz) => Arc::new(array.with_timezone(tz)), None => Arc::new(array), } @@ -2543,10 +2553,10 @@ mod test { fn timestamp_milliseconds_array( input: impl IntoIterator>, - timzezone: Option<&str>, + timezone: Option<&str>, ) -> ArrayRef { let array: TimestampMillisecondArray = input.into_iter().collect(); - match timzezone { + match timezone { Some(tz) => Arc::new(array.with_timezone(tz)), None => Arc::new(array), } @@ -2554,10 +2564,10 @@ mod test { fn timestamp_microseconds_array( input: impl 
IntoIterator>, - timzezone: Option<&str>, + timezone: Option<&str>, ) -> ArrayRef { let array: TimestampMicrosecondArray = input.into_iter().collect(); - match timzezone { + match timezone { Some(tz) => Arc::new(array.with_timezone(tz)), None => Arc::new(array), } @@ -2565,10 +2575,10 @@ fn timestamp_nanoseconds_array( input: impl IntoIterator>, - timzezone: Option<&str>, + timezone: Option<&str>, ) -> ArrayRef { let array: TimestampNanosecondArray = input.into_iter().collect(); - match timzezone { + match timezone { Some(tz) => Arc::new(array.with_timezone(tz)), None => Arc::new(array), } diff --git a/datafusion/core/src/datasource/schema_adapter.rs b/datafusion/core/src/datasource/schema_adapter.rs index e8b64e90900c..f485c49e9109 100644 --- a/datafusion/core/src/datasource/schema_adapter.rs +++ b/datafusion/core/src/datasource/schema_adapter.rs @@ -246,12 +246,13 @@ mod tests { use crate::datasource::schema_adapter::{ SchemaAdapter, SchemaAdapterFactory, SchemaMapper, }; + #[cfg(feature = "parquet")] use parquet::arrow::ArrowWriter; use tempfile::TempDir; #[tokio::test] async fn can_override_schema_adapter() { - // Test shows that SchemaAdapter can add a column that doesn't existin in the + // Test shows that SchemaAdapter can add a column that doesn't exist in the // record batches returned from parquet. This can be useful for schema evolution // where older files may not have all columns. let tmp_dir = TempDir::new().unwrap(); diff --git a/datafusion/core/src/datasource/streaming.rs b/datafusion/core/src/datasource/streaming.rs index 0ba6f85ec3e2..205faee43334 100644 --- a/datafusion/core/src/datasource/streaming.rs +++ b/datafusion/core/src/datasource/streaming.rs @@ -50,7 +50,7 @@ impl StreamingTable { if !schema.contains(partition_schema) { debug!( "target schema does not contain partition schema. \ - Target_schema: {schema:?}. Partiton Schema: {partition_schema:?}" + Target_schema: {schema:?}. Partition Schema: {partition_schema:?}" ); return plan_err!("Mismatch between schema and batches"); } diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 640a9b14a65f..ac48788edb19 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -500,7 +500,7 @@ impl SessionContext { self.execute_logical_plan(plan).await } - /// Creates logical expresssions from SQL query text. + /// Creates logical expressions from SQL query text. /// /// # Example: Parsing SQL queries /// @@ -1352,7 +1352,7 @@ impl SessionContext { self.state.write().register_catalog_list(catalog_list) } - /// Registers a [`ConfigExtension`] as a table option extention that can be + /// Registers a [`ConfigExtension`] as a table option extension that can be /// referenced from SQL statements executed against this context.
pub fn register_table_options_extension(&self, extension: T) { self.state diff --git a/datafusion/core/src/execution/mod.rs b/datafusion/core/src/execution/mod.rs index ac02c7317256..a1b3eab25f33 100644 --- a/datafusion/core/src/execution/mod.rs +++ b/datafusion/core/src/execution/mod.rs @@ -19,6 +19,9 @@ pub mod context; pub mod session_state; +mod session_state_defaults; + +pub use session_state_defaults::SessionStateDefaults; // backwards compatibility pub use crate::datasource::file_format::options; diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 03ce8d3b5892..59cc620dae4d 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -18,30 +18,17 @@ //! [`SessionState`]: information required to run queries in a session use crate::catalog::information_schema::{InformationSchemaProvider, INFORMATION_SCHEMA}; -use crate::catalog::listing_schema::ListingSchemaProvider; -use crate::catalog::schema::{MemorySchemaProvider, SchemaProvider}; -use crate::catalog::{ - CatalogProvider, CatalogProviderList, MemoryCatalogProvider, - MemoryCatalogProviderList, -}; +use crate::catalog::schema::SchemaProvider; +use crate::catalog::{CatalogProviderList, MemoryCatalogProviderList}; use crate::datasource::cte_worktable::CteWorkTable; -use crate::datasource::file_format::arrow::ArrowFormatFactory; -use crate::datasource::file_format::avro::AvroFormatFactory; -use crate::datasource::file_format::csv::CsvFormatFactory; -use crate::datasource::file_format::json::JsonFormatFactory; -#[cfg(feature = "parquet")] -use crate::datasource::file_format::parquet::ParquetFormatFactory; use crate::datasource::file_format::{format_as_file_type, FileFormatFactory}; use crate::datasource::function::{TableFunction, TableFunctionImpl}; -use crate::datasource::provider::{DefaultTableFactory, TableProviderFactory}; +use crate::datasource::provider::TableProviderFactory; use crate::datasource::provider_as_source; use crate::execution::context::{EmptySerializerRegistry, FunctionFactory, QueryPlanner}; -#[cfg(feature = "array_expressions")] -use crate::functions_array; +use crate::execution::SessionStateDefaults; use crate::physical_optimizer::optimizer::PhysicalOptimizer; -use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; -use crate::{functions, functions_aggregate}; use arrow_schema::{DataType, SchemaRef}; use async_trait::async_trait; use chrono::{DateTime, Utc}; @@ -55,7 +42,6 @@ use datafusion_common::{ ResolvedTableReference, TableReference, }; use datafusion_execution::config::SessionConfig; -use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_expr::execution_props::ExecutionProps; @@ -74,6 +60,7 @@ use datafusion_optimizer::{ }; use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::ExecutionPlan; use datafusion_sql::parser::{DFParser, Statement}; use datafusion_sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}; @@ -85,7 +72,6 @@ use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet}; use std::fmt::Debug; use std::sync::Arc; -use url::Url; use uuid::Uuid; /// Execution context for registering data sources and executing 
queries. @@ -1019,7 +1005,7 @@ impl SessionStateBuilder { self } - /// Set tje [`PhysicalOptimizerRule`]s used to optimize plans. + /// Set the [`PhysicalOptimizerRule`]s used to optimize plans. pub fn with_physical_optimizer_rules( mut self, physical_optimizers: Vec>, @@ -1420,169 +1406,6 @@ impl From for SessionStateBuilder { } } -/// Defaults that are used as part of creating a SessionState such as table providers, -/// file formats, registering of builtin functions, etc. -pub struct SessionStateDefaults {} - -impl SessionStateDefaults { - /// returns a map of the default [`TableProviderFactory`]s - pub fn default_table_factories() -> HashMap> { - let mut table_factories: HashMap> = - HashMap::new(); - #[cfg(feature = "parquet")] - table_factories.insert("PARQUET".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("CSV".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("JSON".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("NDJSON".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("AVRO".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("ARROW".into(), Arc::new(DefaultTableFactory::new())); - - table_factories - } - - /// returns the default MemoryCatalogProvider - pub fn default_catalog( - config: &SessionConfig, - table_factories: &HashMap>, - runtime: &Arc, - ) -> MemoryCatalogProvider { - let default_catalog = MemoryCatalogProvider::new(); - - default_catalog - .register_schema( - &config.options().catalog.default_schema, - Arc::new(MemorySchemaProvider::new()), - ) - .expect("memory catalog provider can register schema"); - - Self::register_default_schema(config, table_factories, runtime, &default_catalog); - - default_catalog - } - - /// returns the list of default [`ExprPlanner`]s - pub fn default_expr_planners() -> Vec> { - let expr_planners: Vec> = vec![ - Arc::new(functions::core::planner::CoreFunctionPlanner::default()), - // register crate of array expressions (if enabled) - #[cfg(feature = "array_expressions")] - Arc::new(functions_array::planner::ArrayFunctionPlanner), - #[cfg(feature = "array_expressions")] - Arc::new(functions_array::planner::FieldAccessPlanner), - #[cfg(any( - feature = "datetime_expressions", - feature = "unicode_expressions" - ))] - Arc::new(functions::planner::UserDefinedFunctionPlanner), - ]; - - expr_planners - } - - /// returns the list of default [`ScalarUDF']'s - pub fn default_scalar_functions() -> Vec> { - let mut functions: Vec> = functions::all_default_functions(); - #[cfg(feature = "array_expressions")] - functions.append(&mut functions_array::all_default_array_functions()); - - functions - } - - /// returns the list of default [`AggregateUDF']'s - pub fn default_aggregate_functions() -> Vec> { - functions_aggregate::all_default_aggregate_functions() - } - - /// returns the list of default [`FileFormatFactory']'s - pub fn default_file_formats() -> Vec> { - let file_formats: Vec> = vec![ - #[cfg(feature = "parquet")] - Arc::new(ParquetFormatFactory::new()), - Arc::new(JsonFormatFactory::new()), - Arc::new(CsvFormatFactory::new()), - Arc::new(ArrowFormatFactory::new()), - Arc::new(AvroFormatFactory::new()), - ]; - - file_formats - } - - /// registers all builtin functions - scalar, array and aggregate - pub fn register_builtin_functions(state: &mut SessionState) { - Self::register_scalar_functions(state); - Self::register_array_functions(state); - Self::register_aggregate_functions(state); - } - - /// registers all the builtin scalar 
functions - pub fn register_scalar_functions(state: &mut SessionState) { - functions::register_all(state).expect("can not register built in functions"); - } - - /// registers all the builtin array functions - pub fn register_array_functions(state: &mut SessionState) { - // register crate of array expressions (if enabled) - #[cfg(feature = "array_expressions")] - functions_array::register_all(state).expect("can not register array expressions"); - } - - /// registers all the builtin aggregate functions - pub fn register_aggregate_functions(state: &mut SessionState) { - functions_aggregate::register_all(state) - .expect("can not register aggregate functions"); - } - - /// registers the default schema - pub fn register_default_schema( - config: &SessionConfig, - table_factories: &HashMap>, - runtime: &Arc, - default_catalog: &MemoryCatalogProvider, - ) { - let url = config.options().catalog.location.as_ref(); - let format = config.options().catalog.format.as_ref(); - let (url, format) = match (url, format) { - (Some(url), Some(format)) => (url, format), - _ => return, - }; - let url = url.to_string(); - let format = format.to_string(); - - let url = Url::parse(url.as_str()).expect("Invalid default catalog location!"); - let authority = match url.host_str() { - Some(host) => format!("{}://{}", url.scheme(), host), - None => format!("{}://", url.scheme()), - }; - let path = &url.as_str()[authority.len()..]; - let path = object_store::path::Path::parse(path).expect("Can't parse path"); - let store = ObjectStoreUrl::parse(authority.as_str()) - .expect("Invalid default catalog url"); - let store = match runtime.object_store(store) { - Ok(store) => store, - _ => return, - }; - let factory = match table_factories.get(format.as_str()) { - Some(factory) => factory, - _ => return, - }; - let schema = - ListingSchemaProvider::new(authority, path, factory.clone(), store, format); - let _ = default_catalog - .register_schema("default", Arc::new(schema)) - .expect("Failed to register default schema"); - } - - /// registers the default [`FileFormatFactory`]s - pub fn register_default_file_formats(state: &mut SessionState) { - let formats = SessionStateDefaults::default_file_formats(); - for format in formats { - if let Err(e) = state.register_file_format(format, false) { - log::info!("Unable to register default file format: {e}") - }; - } - } -} - /// Adapter that implements the [`ContextProvider`] trait for a [`SessionState`] /// /// This is used so the SQL planner can access the state of the session without diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs new file mode 100644 index 000000000000..0b0465e44605 --- /dev/null +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -0,0 +1,202 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::catalog::listing_schema::ListingSchemaProvider; +use crate::catalog::{CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider}; +use crate::datasource::file_format::arrow::ArrowFormatFactory; +use crate::datasource::file_format::avro::AvroFormatFactory; +use crate::datasource::file_format::csv::CsvFormatFactory; +use crate::datasource::file_format::json::JsonFormatFactory; +#[cfg(feature = "parquet")] +use crate::datasource::file_format::parquet::ParquetFormatFactory; +use crate::datasource::file_format::FileFormatFactory; +use crate::datasource::provider::{DefaultTableFactory, TableProviderFactory}; +use crate::execution::context::SessionState; +#[cfg(feature = "array_expressions")] +use crate::functions_array; +use crate::{functions, functions_aggregate}; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_execution::runtime_env::RuntimeEnv; +use datafusion_expr::planner::ExprPlanner; +use datafusion_expr::{AggregateUDF, ScalarUDF}; +use std::collections::HashMap; +use std::sync::Arc; +use url::Url; + +/// Defaults that are used as part of creating a SessionState such as table providers, +/// file formats, registering of builtin functions, etc. +pub struct SessionStateDefaults {} + +impl SessionStateDefaults { + /// returns a map of the default [`TableProviderFactory`]s + pub fn default_table_factories() -> HashMap> { + let mut table_factories: HashMap> = + HashMap::new(); + #[cfg(feature = "parquet")] + table_factories.insert("PARQUET".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("CSV".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("JSON".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("NDJSON".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("AVRO".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("ARROW".into(), Arc::new(DefaultTableFactory::new())); + + table_factories + } + + /// returns the default MemoryCatalogProvider + pub fn default_catalog( + config: &SessionConfig, + table_factories: &HashMap>, + runtime: &Arc, + ) -> MemoryCatalogProvider { + let default_catalog = MemoryCatalogProvider::new(); + + default_catalog + .register_schema( + &config.options().catalog.default_schema, + Arc::new(MemorySchemaProvider::new()), + ) + .expect("memory catalog provider can register schema"); + + Self::register_default_schema(config, table_factories, runtime, &default_catalog); + + default_catalog + } + + /// returns the list of default [`ExprPlanner`]s + pub fn default_expr_planners() -> Vec> { + let expr_planners: Vec> = vec![ + Arc::new(functions::core::planner::CoreFunctionPlanner::default()), + // register crate of array expressions (if enabled) + #[cfg(feature = "array_expressions")] + Arc::new(functions_array::planner::ArrayFunctionPlanner), + #[cfg(feature = "array_expressions")] + Arc::new(functions_array::planner::FieldAccessPlanner), + #[cfg(any( + feature = "datetime_expressions", + feature = "unicode_expressions" + ))] + Arc::new(functions::planner::UserDefinedFunctionPlanner), + ]; + + expr_planners + } + + /// returns the list of default [`ScalarUDF']'s + pub fn default_scalar_functions() -> Vec> { + let mut functions: Vec> = functions::all_default_functions(); + #[cfg(feature = "array_expressions")] + functions.append(&mut 
functions_array::all_default_array_functions()); + + functions + } + + /// returns the list of default [`AggregateUDF']'s + pub fn default_aggregate_functions() -> Vec> { + functions_aggregate::all_default_aggregate_functions() + } + + /// returns the list of default [`FileFormatFactory']'s + pub fn default_file_formats() -> Vec> { + let file_formats: Vec> = vec![ + #[cfg(feature = "parquet")] + Arc::new(ParquetFormatFactory::new()), + Arc::new(JsonFormatFactory::new()), + Arc::new(CsvFormatFactory::new()), + Arc::new(ArrowFormatFactory::new()), + Arc::new(AvroFormatFactory::new()), + ]; + + file_formats + } + + /// registers all builtin functions - scalar, array and aggregate + pub fn register_builtin_functions(state: &mut SessionState) { + Self::register_scalar_functions(state); + Self::register_array_functions(state); + Self::register_aggregate_functions(state); + } + + /// registers all the builtin scalar functions + pub fn register_scalar_functions(state: &mut SessionState) { + functions::register_all(state).expect("can not register built in functions"); + } + + /// registers all the builtin array functions + pub fn register_array_functions(state: &mut SessionState) { + // register crate of array expressions (if enabled) + #[cfg(feature = "array_expressions")] + functions_array::register_all(state).expect("can not register array expressions"); + } + + /// registers all the builtin aggregate functions + pub fn register_aggregate_functions(state: &mut SessionState) { + functions_aggregate::register_all(state) + .expect("can not register aggregate functions"); + } + + /// registers the default schema + pub fn register_default_schema( + config: &SessionConfig, + table_factories: &HashMap>, + runtime: &Arc, + default_catalog: &MemoryCatalogProvider, + ) { + let url = config.options().catalog.location.as_ref(); + let format = config.options().catalog.format.as_ref(); + let (url, format) = match (url, format) { + (Some(url), Some(format)) => (url, format), + _ => return, + }; + let url = url.to_string(); + let format = format.to_string(); + + let url = Url::parse(url.as_str()).expect("Invalid default catalog location!"); + let authority = match url.host_str() { + Some(host) => format!("{}://{}", url.scheme(), host), + None => format!("{}://", url.scheme()), + }; + let path = &url.as_str()[authority.len()..]; + let path = object_store::path::Path::parse(path).expect("Can't parse path"); + let store = ObjectStoreUrl::parse(authority.as_str()) + .expect("Invalid default catalog url"); + let store = match runtime.object_store(store) { + Ok(store) => store, + _ => return, + }; + let factory = match table_factories.get(format.as_str()) { + Some(factory) => factory, + _ => return, + }; + let schema = + ListingSchemaProvider::new(authority, path, factory.clone(), store, format); + let _ = default_catalog + .register_schema("default", Arc::new(schema)) + .expect("Failed to register default schema"); + } + + /// registers the default [`FileFormatFactory`]s + pub fn register_default_file_formats(state: &mut SessionState) { + let formats = SessionStateDefaults::default_file_formats(); + for format in formats { + if let Err(e) = state.register_file_format(format, false) { + log::info!("Unable to register default file format: {e}") + }; + } + } +} diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index 81c1c4629a3a..9b9b1db8ff81 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -199,7 +199,7 @@ //! [`QueryPlanner`]: execution::context::QueryPlanner //! 
[`OptimizerRule`]: datafusion_optimizer::optimizer::OptimizerRule //! [`AnalyzerRule`]: datafusion_optimizer::analyzer::AnalyzerRule -//! [`PhysicalOptimizerRule`]: crate::physical_optimizer::optimizer::PhysicalOptimizerRule +//! [`PhysicalOptimizerRule`]: crate::physical_optimizer::PhysicalOptimizerRule //! //! ## Query Planning and Execution Overview //! diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 66067d8cb5c4..e7580d3e33ef 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -18,7 +18,6 @@ //! Utilizing exact statistics from sources to avoid scanning data use std::sync::Arc; -use super::optimizer::PhysicalOptimizerRule; use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_plan::aggregates::AggregateExec; @@ -29,6 +28,7 @@ use crate::scalar::ScalarValue; use datafusion_common::stats::Precision; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_expr::utils::COUNT_STAR_EXPANSION; +use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::udaf::AggregateFunctionExpr; @@ -429,6 +429,7 @@ pub(crate) mod tests { self.column_name(), false, false, + false, ) .unwrap() } diff --git a/datafusion/core/src/physical_optimizer/coalesce_batches.rs b/datafusion/core/src/physical_optimizer/coalesce_batches.rs index 42b7463600dc..da0e44c8de4e 100644 --- a/datafusion/core/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/core/src/physical_optimizer/coalesce_batches.rs @@ -23,7 +23,6 @@ use std::sync::Arc; use crate::{ config::ConfigOptions, error::Result, - physical_optimizer::PhysicalOptimizerRule, physical_plan::{ coalesce_batches::CoalesceBatchesExec, filter::FilterExec, joins::HashJoinExec, repartition::RepartitionExec, Partitioning, @@ -31,6 +30,7 @@ use crate::{ }; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_physical_optimizer::PhysicalOptimizerRule; /// Optimizer rule that introduces CoalesceBatchesExec to avoid overhead with small batches that /// are produced by highly selective filters diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 940b09131a77..ddb7d36fb595 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -21,13 +21,13 @@ use std::sync::Arc; use crate::error::Result; -use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::ExecutionPlan; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_physical_expr::{physical_exprs_equal, AggregateExpr, PhysicalExpr}; +use datafusion_physical_optimizer::PhysicalOptimizerRule; /// CombinePartialFinalAggregate optimizer rule combines the adjacent Partial and Final AggregateExecs /// into a Single AggregateExec if their grouping exprs and aggregate exprs equal. 
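// Illustrative sketch (not part of this patch): with `PhysicalOptimizerRule` now
// provided by the new `datafusion-physical-optimizer` crate, a custom rule is
// written against that crate's trait. `NoopPhysicalRule` is a hypothetical name;
// the three methods follow the trait definition that this diff moves out of
// `physical_optimizer/optimizer.rs`.
use std::sync::Arc;

use datafusion_common::config::ConfigOptions;
use datafusion_common::Result;
use datafusion_physical_optimizer::PhysicalOptimizerRule;
use datafusion_physical_plan::ExecutionPlan;

#[derive(Debug, Default)]
struct NoopPhysicalRule;

impl PhysicalOptimizerRule for NoopPhysicalRule {
    // Return the plan unchanged; a real rule would rewrite the plan here.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(plan)
    }

    // A human readable name for this optimizer rule.
    fn name(&self) -> &str {
        "noop_physical_rule"
    }

    // Returning true lets the planner verify the rule preserved the plan schema.
    fn schema_check(&self) -> bool {
        true
    }
}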
@@ -288,6 +288,7 @@ mod tests { name, false, false, + false, ) .unwrap() } @@ -354,7 +355,7 @@ mod tests { PhysicalGroupBy::default(), aggr_expr, ); - // should combine the Partial/Final AggregateExecs to tne Single AggregateExec + // should combine the Partial/Final AggregateExecs to the Single AggregateExec let expected = &[ "AggregateExec: mode=Single, gby=[], aggr=[COUNT(1)]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c]", @@ -378,6 +379,7 @@ mod tests { "Sum(b)", false, false, + false, )?]; let groups: Vec<(Arc, String)> = vec![(col("c", &schema)?, "c".to_string())]; @@ -394,7 +396,7 @@ mod tests { let final_group_by = PhysicalGroupBy::new_single(groups); let plan = final_aggregate_exec(partial_agg, final_group_by, aggr_expr); - // should combine the Partial/Final AggregateExecs to tne Single AggregateExec + // should combine the Partial/Final AggregateExecs to the Single AggregateExec let expected = &[ "AggregateExec: mode=Single, gby=[c@2 as c], aggr=[Sum(b)]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c]", diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index f7c2aee578ba..62ac9089e2b4 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -24,14 +24,12 @@ use std::fmt::Debug; use std::sync::Arc; -use super::output_requirements::OutputRequirementExec; use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_optimizer::utils::{ add_sort_above_with_check, is_coalesce_partitions, is_repartition, is_sort_preserving_merge, }; -use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::joins::{ @@ -56,6 +54,8 @@ use datafusion_physical_expr::{ use datafusion_physical_plan::windows::{get_best_fitting_window, BoundedWindowAggExec}; use datafusion_physical_plan::ExecutionPlanProperties; +use datafusion_physical_optimizer::output_requirements::OutputRequirementExec; +use datafusion_physical_optimizer::PhysicalOptimizerRule; use itertools::izip; /// The `EnforceDistribution` rule ensures that distribution requirements are @@ -392,7 +392,7 @@ fn adjust_input_keys_ordering( let expr = proj.expr(); // For Projection, we need to transform the requirements to the columns before the Projection // And then to push down the requirements - // Construct a mapping from new name to the orginal Column + // Construct a mapping from new name to the original Column let new_required = map_columns_before_projection(&requirements.data, expr); if new_required.len() == requirements.data.len() { requirements.children[0].data = new_required; @@ -566,7 +566,7 @@ fn shift_right_required( }) .collect::>(); - // if the parent required are all comming from the right side, the requirements can be pushdown + // if the parent required are all coming from the right side, the requirements can be pushdown (new_right_required.len() == parent_required.len()).then_some(new_right_required) } @@ -1290,7 +1290,6 @@ pub(crate) mod tests { use crate::datasource::object_store::ObjectStoreUrl; use crate::datasource::physical_plan::{CsvExec, FileScanConfig, ParquetExec}; use crate::physical_optimizer::enforce_sorting::EnforceSorting; - use crate::physical_optimizer::output_requirements::OutputRequirements; use 
crate::physical_optimizer::test_utils::{ check_integrity, coalesce_partitions_exec, repartition_exec, }; @@ -1301,6 +1300,7 @@ pub(crate) mod tests { use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{displayable, DisplayAs, DisplayFormatType, Statistics}; + use datafusion_physical_optimizer::output_requirements::OutputRequirements; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::ScalarValue; @@ -1472,6 +1472,7 @@ pub(crate) mod tests { b'"', None, None, + false, FileCompressionType::UNCOMPRESSED, )) } @@ -1496,6 +1497,7 @@ pub(crate) mod tests { b'"', None, None, + false, FileCompressionType::UNCOMPRESSED, )) } @@ -3770,6 +3772,7 @@ pub(crate) mod tests { b'"', None, None, + false, compression_type, )), vec![("a".to_string(), "a".to_string())], diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 24306647c686..e577c5336086 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -49,7 +49,6 @@ use crate::physical_optimizer::utils::{ is_coalesce_partitions, is_limit, is_repartition, is_sort, is_sort_preserving_merge, is_union, is_window, }; -use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; @@ -66,6 +65,7 @@ use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::partial_sort::PartialSortExec; use datafusion_physical_plan::ExecutionPlanProperties; +use datafusion_physical_optimizer::PhysicalOptimizerRule; use itertools::izip; /// This rule inspects [`SortExec`]'s in the given physical plan and removes the @@ -631,6 +631,7 @@ mod tests { use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::{col, Column, NotExpr}; + use datafusion_physical_optimizer::PhysicalOptimizerRule; use rstest::rstest; fn create_test_schema() -> Result { diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 1613e5089860..b849df88e4aa 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -27,7 +27,6 @@ use std::sync::Arc; use crate::config::ConfigOptions; use crate::error::Result; -use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; use crate::physical_plan::joins::{ CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, @@ -42,6 +41,7 @@ use datafusion_common::{internal_err, JoinSide, JoinType}; use datafusion_expr::sort_properties::SortProperties; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use datafusion_physical_optimizer::PhysicalOptimizerRule; /// The [`JoinSelection`] rule tries to modify a given plan so that it can /// accommodate infinite sources and optimize joins in the plan according to @@ -719,7 +719,7 @@ mod tests_statistical { use rstest::rstest; - /// Return statistcs for empty table + /// Return statistics for empty table fn empty_statistics() -> Statistics { Statistics { num_rows: Precision::Absent, @@ -737,7 +737,7 @@ mod tests_statistical { ) 
} - /// Return statistcs for small table + /// Return statistics for small table fn small_statistics() -> Statistics { let (threshold_num_rows, threshold_byte_size) = get_thresholds(); Statistics { @@ -747,7 +747,7 @@ mod tests_statistical { } } - /// Return statistcs for big table + /// Return statistics for big table fn big_statistics() -> Statistics { let (threshold_num_rows, threshold_byte_size) = get_thresholds(); Statistics { @@ -757,7 +757,7 @@ mod tests_statistical { } } - /// Return statistcs for big table + /// Return statistics for big table fn bigger_statistics() -> Statistics { let (threshold_num_rows, threshold_byte_size) = get_thresholds(); Statistics { diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs index f9d5a4c186ee..b5d3f432d84d 100644 --- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -20,7 +20,6 @@ use std::sync::Arc; -use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::AggregateExec; use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; @@ -29,6 +28,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; +use datafusion_physical_optimizer::PhysicalOptimizerRule; use itertools::Itertools; /// An optimizer rule that passes a `limit` hint into grouped aggregations which don't require all diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index 9ad05bf496e5..7c508eeef878 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -29,17 +29,16 @@ pub mod enforce_sorting; pub mod join_selection; pub mod limited_distinct_aggregation; pub mod optimizer; -pub mod output_requirements; pub mod projection_pushdown; pub mod pruning; pub mod replace_with_order_preserving_variants; pub mod sanity_checker; -mod sort_pushdown; +#[cfg(test)] +pub mod test_utils; pub mod topk_aggregation; pub mod update_aggr_exprs; -mod utils; -#[cfg(test)] -pub mod test_utils; +mod sort_pushdown; +mod utils; -pub use optimizer::PhysicalOptimizerRule; +pub use datafusion_physical_optimizer::*; diff --git a/datafusion/core/src/physical_optimizer/optimizer.rs b/datafusion/core/src/physical_optimizer/optimizer.rs index 2d9744ad23dd..6449dbea0ddf 100644 --- a/datafusion/core/src/physical_optimizer/optimizer.rs +++ b/datafusion/core/src/physical_optimizer/optimizer.rs @@ -17,11 +17,11 @@ //! 
Physical optimizer traits +use datafusion_physical_optimizer::PhysicalOptimizerRule; use std::sync::Arc; use super::projection_pushdown::ProjectionPushdown; use super::update_aggr_exprs::OptimizeAggregateOrder; -use crate::config::ConfigOptions; use crate::physical_optimizer::aggregate_statistics::AggregateStatistics; use crate::physical_optimizer::coalesce_batches::CoalesceBatches; use crate::physical_optimizer::combine_partial_final_agg::CombinePartialFinalAggregate; @@ -32,32 +32,6 @@ use crate::physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggr use crate::physical_optimizer::output_requirements::OutputRequirements; use crate::physical_optimizer::sanity_checker::SanityCheckPlan; use crate::physical_optimizer::topk_aggregation::TopKAggregation; -use crate::{error::Result, physical_plan::ExecutionPlan}; - -/// `PhysicalOptimizerRule` transforms one ['ExecutionPlan'] into another which -/// computes the same results, but in a potentially more efficient way. -/// -/// Use [`SessionState::add_physical_optimizer_rule`] to register additional -/// `PhysicalOptimizerRule`s. -/// -/// [`SessionState::add_physical_optimizer_rule`]: https://docs.rs/datafusion/latest/datafusion/execution/session_state/struct.SessionState.html#method.add_physical_optimizer_rule -pub trait PhysicalOptimizerRule { - /// Rewrite `plan` to an optimized form - fn optimize( - &self, - plan: Arc, - config: &ConfigOptions, - ) -> Result>; - - /// A human readable name for this optimizer rule - fn name(&self) -> &str; - - /// A flag to indicate whether the physical planner should valid the rule will not - /// change the schema of the plan after the rewriting. - /// Some of the optimization rules might change the nullable properties of the schema - /// and should disable the schema check. - fn schema_check(&self) -> bool; -} /// A rule-based physical optimizer. #[derive(Clone)] diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 3c2be59f7504..d0d0c985b8b6 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -18,13 +18,12 @@ //! This file implements the `ProjectionPushdown` physical optimization rule. //! The function [`remove_unnecessary_projections`] tries to push down all //! projections one by one if the operator below is amenable to this. If a -//! projection reaches a source, it can even dissappear from the plan entirely. +//! projection reaches a source, it can even disappear from the plan entirely. 
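// Illustrative sketch: the trait definition removed above now lives in the
// `datafusion-physical-optimizer` crate, and its old doc comment pointed at
// registering extra rules on the session state. One way to install a custom rule
// set is the `with_physical_optimizer_rules` setter shown earlier in the
// `SessionStateBuilder` hunk. The trait-object bounds and the `new`/`build`
// builder calls below are assumptions, and note that the setter replaces the
// default rule list rather than appending to it.
use std::sync::Arc;

use datafusion::execution::session_state::{SessionState, SessionStateBuilder};
use datafusion_physical_optimizer::PhysicalOptimizerRule;

// Build a SessionState whose physical optimizer runs only the given rules.
fn state_with_rules(
    rules: Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>>,
) -> SessionState {
    SessionStateBuilder::new()
        .with_physical_optimizer_rules(rules)
        .build()
}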
use std::collections::HashMap; use std::sync::Arc; use super::output_requirements::OutputRequirementExec; -use super::PhysicalOptimizerRule; use crate::datasource::physical_plan::CsvExec; use crate::error::Result; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; @@ -55,6 +54,7 @@ use datafusion_physical_expr::{ use datafusion_physical_plan::streaming::StreamingTableExec; use datafusion_physical_plan::union::UnionExec; +use datafusion_physical_optimizer::PhysicalOptimizerRule; use itertools::Itertools; /// This rule inspects [`ProjectionExec`]'s in the given physical plan and tries to @@ -186,6 +186,7 @@ fn try_swapping_with_csv( csv.quote(), csv.escape(), csv.comment(), + csv.newlines_in_values(), csv.file_compression_type, )) as _ }) @@ -1700,6 +1701,7 @@ mod tests { 0, None, None, + false, FileCompressionType::UNCOMPRESSED, )) } @@ -1723,6 +1725,7 @@ mod tests { 0, None, None, + false, FileCompressionType::UNCOMPRESSED, )) } diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index a7ce29bdc7e3..3c18e53497fd 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -609,6 +609,8 @@ impl PruningPredicate { /// /// This happens if the predicate is a literal `true` and /// literal_guarantees is empty. + /// + /// This can happen when a predicate is simplified to a constant `true` pub fn always_true(&self) -> bool { is_always_true(&self.predicate_expr) && self.literal_guarantees.is_empty() } @@ -623,7 +625,7 @@ impl PruningPredicate { /// /// This is useful to avoid fetching statistics for columns that will not be /// used in the predicate. For example, it can be used to avoid reading - /// uneeded bloom filters (a non trivial operation). + /// unneeded bloom filters (a non trivial operation). pub fn literal_columns(&self) -> Vec { let mut seen = HashSet::new(); self.literal_guarantees @@ -736,12 +738,25 @@ impl RequiredColumns { Self::default() } - /// Returns number of unique columns - pub(crate) fn n_columns(&self) -> usize { - self.iter() - .map(|(c, _s, _f)| c) - .collect::>() - .len() + /// Returns Some(column) if this is a single column predicate. + /// + /// Returns None if this is a multi-column predicate. 
+ /// + /// Examples: + /// * `a > 5 OR a < 10` returns `Some(a)` + /// * `a > 5 OR b < 10` returns `None` + /// * `true` returns None + pub(crate) fn single_column(&self) -> Option<&phys_expr::Column> { + if self.columns.windows(2).all(|w| { + // check if all columns are the same (ignoring statistics and field) + let c1 = &w[0].0; + let c2 = &w[1].0; + c1 == c2 + }) { + self.columns.first().map(|r| &r.0) + } else { + None + } } /// Returns an iterator over items in columns (see doc on diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 013155b8400a..6565e3e7d0d2 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -1503,6 +1503,7 @@ mod tests { b'"', None, None, + false, FileCompressionType::UNCOMPRESSED, )) } diff --git a/datafusion/core/src/physical_optimizer/sanity_checker.rs b/datafusion/core/src/physical_optimizer/sanity_checker.rs index 083b42f7400b..6e37c3f40ffa 100644 --- a/datafusion/core/src/physical_optimizer/sanity_checker.rs +++ b/datafusion/core/src/physical_optimizer/sanity_checker.rs @@ -24,7 +24,6 @@ use std::sync::Arc; use crate::error::Result; -use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::ExecutionPlan; use datafusion_common::config::{ConfigOptions, OptimizerOptions}; @@ -34,6 +33,7 @@ use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supp use datafusion_physical_plan::joins::SymmetricHashJoinExec; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; +use datafusion_physical_optimizer::PhysicalOptimizerRule; use itertools::izip; /// The SanityCheckPlan rule rejects the following query plans: diff --git a/datafusion/core/src/physical_optimizer/topk_aggregation.rs b/datafusion/core/src/physical_optimizer/topk_aggregation.rs index b754ee75ef3e..82cf44ad7796 100644 --- a/datafusion/core/src/physical_optimizer/topk_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/topk_aggregation.rs @@ -19,7 +19,6 @@ use std::sync::Arc; -use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::AggregateExec; use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; use crate::physical_plan::filter::FilterExec; @@ -34,6 +33,7 @@ use datafusion_common::Result; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_optimizer::PhysicalOptimizerRule; use itertools::Itertools; /// An optimizer rule that passes a `limit` hint to aggregations if the whole result is not needed diff --git a/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs b/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs index 1ad4179cefd8..f8edf73e3d2a 100644 --- a/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs +++ b/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs @@ -20,14 +20,13 @@ use std::sync::Arc; -use super::PhysicalOptimizerRule; - use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{plan_datafusion_err, Result}; use datafusion_physical_expr::{ reverse_order_bys, AggregateExpr, EquivalenceProperties, PhysicalSortRequirement, }; +use datafusion_physical_optimizer::PhysicalOptimizerRule; use 
datafusion_physical_plan::aggregates::concat_slices; use datafusion_physical_plan::windows::get_ordered_partition_by_indices; use datafusion_physical_plan::{ @@ -128,7 +127,7 @@ impl PhysicalOptimizerRule for OptimizeAggregateOrder { /// # Returns /// /// Returns `Ok(converted_aggr_exprs)` if the conversion process completes -/// successfully. Any errors occuring during the conversion process are +/// successfully. Any errors occurring during the conversion process are /// passed through. fn try_convert_aggregate_if_better( aggr_exprs: Vec>, diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index efc83d8f6b5c..329d343f13fc 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -38,7 +38,6 @@ use crate::logical_expr::{ }; use crate::logical_expr::{Limit, Values}; use crate::physical_expr::{create_physical_expr, create_physical_exprs}; -use crate::physical_optimizer::optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::analyze::AnalyzeExec; use crate::physical_plan::empty::EmptyExec; @@ -91,6 +90,7 @@ use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_sql::utils::window_expr_common_partition_keys; use async_trait::async_trait; +use datafusion_physical_optimizer::PhysicalOptimizerRule; use futures::{StreamExt, TryStreamExt}; use itertools::{multiunzip, Itertools}; use log::{debug, trace}; @@ -1872,22 +1872,27 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( )?), None => None, }; + let ordering_reqs: Vec = physical_sort_exprs.clone().unwrap_or(vec![]); - let agg_expr = udaf::create_aggregate_expr( + + let agg_expr = udaf::create_aggregate_expr_with_dfschema( fun, &physical_args, args, &sort_exprs, &ordering_reqs, - physical_input_schema, + logical_input_schema, name, ignore_nulls, *distinct, + false, )?; + (agg_expr, filter, physical_sort_exprs) } }; + Ok((agg_expr, filter, order_by)) } other => internal_err!("Invalid aggregate expression '{other:?}'"), diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index e8550a79cb0e..5cb1b6ea7017 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -99,6 +99,7 @@ pub fn scan_partitioned_csv(partitions: usize, work_dir: &Path) -> Result SchemaRef { Arc::new(Schema::new(vec![ @@ -1087,3 +1088,24 @@ async fn test_fn_array_to_string() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_fn_map() -> Result<()> { + let expr = map( + vec![lit("a"), lit("b"), lit("c")], + vec![lit(1), lit(2), lit(3)], + ); + let expected = [ + "+---------------------------------------------------------------------------------------+", + "| map(make_array(Utf8(\"a\"),Utf8(\"b\"),Utf8(\"c\")),make_array(Int32(1),Int32(2),Int32(3))) |", + "+---------------------------------------------------------------------------------------+", + "| {a: 1, b: 2, c: 3} |", + "| {a: 1, b: 2, c: 3} |", + "| {a: 1, b: 2, c: 3} |", + "| {a: 1, b: 2, c: 3} |", + "+---------------------------------------------------------------------------------------+", + ]; + assert_fn_batches!(expr, expected); + + Ok(()) +} diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 1b2a6770cf01..bc01ada1e04b 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -54,11 +54,11 @@ use datafusion_execution::runtime_env::RuntimeEnv; use 
datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - array_agg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, + cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion_functions_aggregate::expr_fn::{avg, count, sum}; +use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, sum}; #[tokio::test] async fn test_count_wildcard_on_sort() -> Result<()> { @@ -99,7 +99,7 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { // In the same SessionContext, AliasGenerator will increase subquery_alias id by 1 // https://github.com/apache/datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43 - // for compare difference betwwen sql and df logical plan, we need to create a new SessionContext here + // for compare difference between sql and df logical plan, we need to create a new SessionContext here let ctx = create_join_context()?; let df_results = ctx .table("t1") @@ -1389,7 +1389,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { let expected = vec![ "Projection: shapes.shape_id [shape_id:UInt32]", " Unnest: lists[shape_id2] structs[] [shape_id:UInt32, shape_id2:UInt32;N]", - " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} });N]", + " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[array_agg(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]", " TableScan: shapes projection=[shape_id] [shape_id:UInt32]", ]; @@ -1973,7 +1973,7 @@ async fn test_array_agg() -> Result<()> { let expected = [ "+-------------------------------------+", - "| ARRAY_AGG(test.a) |", + "| array_agg(test.a) |", "+-------------------------------------+", "| [abcDEF, abc123, CBAdef, 123AbcDef] |", "+-------------------------------------+", diff --git a/datafusion/core/tests/fifo/mod.rs b/datafusion/core/tests/fifo/mod.rs index 1df97b1636c7..6efbb9b029de 100644 --- a/datafusion/core/tests/fifo/mod.rs +++ b/datafusion/core/tests/fifo/mod.rs @@ -6,7 +6,7 @@ // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // -//http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an @@ -16,38 +16,37 @@ // under the License. //! This test demonstrates the DataFusion FIFO capabilities. -//! 
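// Illustrative sketch: the FIFO tests below stream CSV rows through Unix named
// pipes. The helper that creates the pipe is not shown here; a minimal version
// (hypothetical name `create_fifo_file`) using the `nix` crate that the test
// module imports could look like this, with the pipe placed in a
// `tempfile::TempDir`:
use std::path::PathBuf;

use nix::sys::stat;
use nix::unistd;
use tempfile::TempDir;

fn create_fifo_file(tmp_dir: &TempDir, file_name: &str) -> nix::Result<PathBuf> {
    // mkfifo creates the named pipe; the writer task and the CSV reader then
    // open it by path like an ordinary file.
    let file_path = tmp_dir.path().join(file_name);
    unistd::mkfifo(&file_path, stat::Mode::S_IRWXU)?;
    Ok(file_path)
}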
+ #[cfg(target_family = "unix")] #[cfg(test)] mod unix_test { - use datafusion_common::instant::Instant; - use std::fs::{File, OpenOptions}; - use std::io::Write; + use std::fs::File; use std::path::PathBuf; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; - use std::thread; use std::time::Duration; use arrow::array::Array; use arrow::csv::ReaderBuilder; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::SchemaRef; - use futures::StreamExt; - use nix::sys::stat; - use nix::unistd; - use tempfile::TempDir; - use tokio::task::{spawn_blocking, JoinHandle}; - use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::datasource::TableProvider; use datafusion::{ prelude::{CsvReadOptions, SessionConfig, SessionContext}, test_util::{aggr_test_schema, arrow_test_data}, }; - use datafusion_common::{exec_err, DataFusionError, Result}; + use datafusion_common::instant::Instant; + use datafusion_common::{exec_err, Result}; use datafusion_expr::Expr; + use futures::StreamExt; + use nix::sys::stat; + use nix::unistd; + use tempfile::TempDir; + use tokio::io::AsyncWriteExt; + use tokio::task::{spawn_blocking, JoinHandle}; + /// Makes a TableProvider for a fifo file fn fifo_table( schema: SchemaRef, @@ -71,8 +70,8 @@ mod unix_test { } } - fn write_to_fifo( - mut file: &File, + async fn write_to_fifo( + file: &mut tokio::fs::File, line: &str, ref_time: Instant, broken_pipe_timeout: Duration, @@ -80,11 +79,11 @@ mod unix_test { // We need to handle broken pipe error until the reader is ready. This // is why we use a timeout to limit the wait duration for the reader. // If the error is different than broken pipe, we fail immediately. - while let Err(e) = file.write_all(line.as_bytes()) { + while let Err(e) = file.write_all(line.as_bytes()).await { if e.raw_os_error().unwrap() == 32 { let interval = Instant::now().duration_since(ref_time); if interval < broken_pipe_timeout { - thread::sleep(Duration::from_millis(100)); + tokio::time::sleep(Duration::from_millis(50)).await; continue; } } @@ -93,28 +92,38 @@ mod unix_test { Ok(()) } - fn create_writing_thread( + /// This function creates a writing task for the FIFO file. To verify + /// incremental processing, it waits for a signal to continue writing after + /// a certain number of lines are written. 
+ #[allow(clippy::disallowed_methods)] + fn create_writing_task( file_path: PathBuf, header: String, lines: Vec, - waiting_lock: Arc, - wait_until: usize, + waiting_signal: Arc, + send_before_waiting: usize, ) -> JoinHandle<()> { // Timeout for a long period of BrokenPipe error let broken_pipe_timeout = Duration::from_secs(10); - let sa = file_path.clone(); - // Spawn a new thread to write to the FIFO file - #[allow(clippy::disallowed_methods)] // spawn allowed only in tests - spawn_blocking(move || { - let file = OpenOptions::new().write(true).open(sa).unwrap(); + // Spawn a new task to write to the FIFO file + tokio::spawn(async move { + let mut file = tokio::fs::OpenOptions::new() + .write(true) + .open(file_path) + .await + .unwrap(); // Reference time to use when deciding to fail the test let execution_start = Instant::now(); - write_to_fifo(&file, &header, execution_start, broken_pipe_timeout).unwrap(); + write_to_fifo(&mut file, &header, execution_start, broken_pipe_timeout) + .await + .unwrap(); for (cnt, line) in lines.iter().enumerate() { - while waiting_lock.load(Ordering::SeqCst) && cnt > wait_until { - thread::sleep(Duration::from_millis(50)); + while waiting_signal.load(Ordering::SeqCst) && cnt > send_before_waiting { + tokio::time::sleep(Duration::from_millis(50)).await; } - write_to_fifo(&file, line, execution_start, broken_pipe_timeout).unwrap(); + write_to_fifo(&mut file, line, execution_start, broken_pipe_timeout) + .await + .unwrap(); } drop(file); }) @@ -125,6 +134,8 @@ mod unix_test { const TEST_BATCH_SIZE: usize = 20; // Number of lines written to FIFO const TEST_DATA_SIZE: usize = 20_000; + // Number of lines to write before waiting to verify incremental processing + const SEND_BEFORE_WAITING: usize = 2 * TEST_BATCH_SIZE; // Number of lines what can be joined. Each joinable key produced 20 lines with // aggregate_test_100 dataset. We will use these joinable keys for understanding // incremental execution. @@ -132,7 +143,7 @@ mod unix_test { // This test provides a relatively realistic end-to-end scenario where // we swap join sides to accommodate a FIFO source. - #[tokio::test(flavor = "multi_thread", worker_threads = 8)] + #[tokio::test] async fn unbounded_file_with_swapped_join() -> Result<()> { // Create session context let config = SessionConfig::new() @@ -162,8 +173,8 @@ mod unix_test { .zip(0..TEST_DATA_SIZE) .map(|(a1, a2)| format!("{a1},{a2}\n")) .collect::>(); - // Create writing threads for the left and right FIFO files - let task = create_writing_thread( + // Create writing tasks for the left and right FIFO files + let task = create_writing_task( fifo_path.clone(), "a1,a2\n".to_owned(), lines, @@ -190,7 +201,16 @@ mod unix_test { ) .await?; // Execute the query - let df = ctx.sql("SELECT t1.a2, t2.c1, t2.c4, t2.c5 FROM left as t1 JOIN right as t2 ON t1.a1 = t2.c1").await?; + let df = ctx + .sql( + "SELECT + t1.a2, t2.c1, t2.c4, t2.c5 + FROM + left as t1, right as t2 + WHERE + t1.a1 = t2.c1", + ) + .await?; let mut stream = df.execute_stream().await?; while (stream.next().await).is_some() { waiting.store(false, Ordering::SeqCst); @@ -199,16 +219,9 @@ mod unix_test { Ok(()) } - #[derive(Debug, PartialEq)] - enum JoinOperation { - LeftUnmatched, - RightUnmatched, - Equal, - } - - // This test provides a relatively realistic end-to-end scenario where - // we change the join into a [SymmetricHashJoin] to accommodate two - // unbounded (FIFO) sources. 
+ /// This test provides a relatively realistic end-to-end scenario where + /// we change the join into a `SymmetricHashJoinExec` to accommodate two + /// unbounded (FIFO) sources. #[tokio::test] async fn unbounded_file_with_symmetric_join() -> Result<()> { // Create session context @@ -247,19 +260,18 @@ mod unix_test { let df = ctx .sql( "SELECT - t1.a1, - t1.a2, - t2.a1, - t2.a2 + t1.a1, t1.a2, t2.a1, t2.a2 FROM - left as t1 FULL - JOIN right as t2 ON t1.a2 = t2.a2 - AND t1.a1 > t2.a1 + 4 - AND t1.a1 < t2.a1 + 9", + left as t1 + FULL JOIN + right as t2 + ON + t1.a2 = t2.a2 AND + t1.a1 > t2.a1 + 4 AND + t1.a1 < t2.a1 + 9", ) .await?; let mut stream = df.execute_stream().await?; - let mut operations = vec![]; // Tasks let mut tasks: Vec> = vec![]; @@ -273,54 +285,43 @@ mod unix_test { .map(|(a1, a2)| format!("{a1},{a2}\n")) .collect::>(); - // Create writing threads for the left and right FIFO files - tasks.push(create_writing_thread( + // Create writing tasks for the left and right FIFO files + tasks.push(create_writing_task( left_fifo, "a1,a2\n".to_owned(), lines.clone(), waiting.clone(), - TEST_BATCH_SIZE, + SEND_BEFORE_WAITING, )); - tasks.push(create_writing_thread( + tasks.push(create_writing_task( right_fifo, "a1,a2\n".to_owned(), - lines.clone(), + lines, waiting.clone(), - TEST_BATCH_SIZE, + SEND_BEFORE_WAITING, )); - // Partial. + // Collect output data: + let (mut equal, mut left, mut right) = (0, 0, 0); while let Some(Ok(batch)) = stream.next().await { waiting.store(false, Ordering::SeqCst); let left_unmatched = batch.column(2).null_count(); let right_unmatched = batch.column(0).null_count(); - let op = if left_unmatched == 0 && right_unmatched == 0 { - JoinOperation::Equal - } else if right_unmatched > left_unmatched { - JoinOperation::RightUnmatched + if left_unmatched == 0 && right_unmatched == 0 { + equal += 1; + } else if right_unmatched <= left_unmatched { + left += 1; } else { - JoinOperation::LeftUnmatched + right += 1; }; - operations.push(op); } futures::future::try_join_all(tasks).await.unwrap(); - // The SymmetricHashJoin executor produces FULL join results at every - // pruning, which happens before it reaches the end of input and more - // than once. In this test, we feed partially joinable data to both - // sides in order to ensure that left or right unmatched results are - // generated more than once during the test. - assert!( - operations - .iter() - .filter(|&n| JoinOperation::RightUnmatched.eq(n)) - .count() - > 1 - && operations - .iter() - .filter(|&n| JoinOperation::LeftUnmatched.eq(n)) - .count() - > 1 - ); + // The symmetric hash join algorithm produces FULL join results at + // every pruning, which happens before it reaches the end of input and + // more than once. In this test, we feed partially joinable data to + // both sides in order to ensure that left or right unmatched results + // are generated as expected. + assert!(equal >= 0 && left > 1 && right > 1); Ok(()) } @@ -341,17 +342,14 @@ mod unix_test { (source_fifo_path.clone(), source_fifo_path.display()); // Tasks let mut tasks: Vec> = vec![]; - // TEST_BATCH_SIZE + 1 rows will be provided. However, after processing precisely - // TEST_BATCH_SIZE rows, the program will pause and wait for a batch to be read in another - // thread. This approach ensures that the pipeline remains unbroken. 
- tasks.push(create_writing_thread( + tasks.push(create_writing_task( source_fifo_path_thread, "a1,a2\n".to_owned(), (0..TEST_DATA_SIZE) .map(|_| "a,1\n".to_string()) .collect::>(), waiting, - TEST_BATCH_SIZE, + SEND_BEFORE_WAITING, )); // Create a new temporary FIFO file let sink_fifo_path = create_fifo_file(&tmp_dir, "sink.csv")?; @@ -370,8 +368,8 @@ mod unix_test { let mut reader = ReaderBuilder::new(schema) .with_batch_size(TEST_BATCH_SIZE) + .with_header(true) .build(file) - .map_err(|e| DataFusionError::Internal(e.to_string())) .unwrap(); while let Some(Ok(_)) = reader.next() { diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index a04f4f349122..736560da97db 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -113,6 +113,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str "sum1", false, false, + false, ) .unwrap()]; let expr = group_by_columns diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 17dbf3a0ff28..f1cca66712d7 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -83,7 +83,7 @@ fn less_than_100_join_filter(schema1: Arc, _schema2: Arc) -> Joi } fn col_lt_col_filter(schema1: Arc, schema2: Arc) -> JoinFilter { - let less_than_100 = Arc::new(BinaryExpr::new( + let less_filter = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 1)), Operator::Lt, Arc::new(Column::new("x", 0)), @@ -99,11 +99,19 @@ fn col_lt_col_filter(schema1: Arc, schema2: Arc) -> JoinFilter { }, ]; let intermediate_schema = Schema::new(vec![ - schema1.field_with_name("x").unwrap().to_owned(), - schema2.field_with_name("x").unwrap().to_owned(), + schema1 + .field_with_name("x") + .unwrap() + .clone() + .with_nullable(true), + schema2 + .field_with_name("x") + .unwrap() + .clone() + .with_nullable(true), ]); - JoinFilter::new(less_than_100, column_indices, intermediate_schema) + JoinFilter::new(less_filter, column_indices, intermediate_schema) } #[tokio::test] @@ -217,6 +225,8 @@ async fn test_semi_join_1k() { #[tokio::test] async fn test_semi_join_1k_filtered() { + // NLJ vs HJ gives wrong result + // Tracked in https://github.com/apache/datafusion/issues/11537 JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), @@ -239,17 +249,20 @@ async fn test_anti_join_1k() { .await } -// Test failed for now. 
https://github.com/apache/datafusion/issues/10872 -#[ignore] #[tokio::test] +#[ignore] +// flaky test giving 1 rows difference sometimes +// https://github.com/apache/datafusion/issues/11555 async fn test_anti_join_1k_filtered() { + // NLJ vs HJ gives wrong result + // Tracked in https://github.com/apache/datafusion/issues/11537 JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::LeftAnti, - Some(Box::new(less_than_100_join_filter)), + Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .run_test(&[JoinTestType::HjSmj], false) .await } @@ -422,12 +435,13 @@ impl JoinFuzzTestCase { let session_config = SessionConfig::new().with_batch_size(*batch_size); let ctx = SessionContext::new_with_config(session_config); let task_ctx = ctx.task_ctx(); - let smj = self.sort_merge_join(); - let smj_collected = collect(smj, task_ctx.clone()).await.unwrap(); let hj = self.hash_join(); let hj_collected = collect(hj, task_ctx.clone()).await.unwrap(); + let smj = self.sort_merge_join(); + let smj_collected = collect(smj, task_ctx.clone()).await.unwrap(); + let nlj = self.nested_loop_join(); let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap(); @@ -437,11 +451,12 @@ impl JoinFuzzTestCase { let nlj_rows = nlj_collected.iter().fold(0, |acc, b| acc + b.num_rows()); if debug { - println!("The debug is ON. Input data will be saved"); let fuzz_debug = "fuzz_test_debug"; std::fs::remove_dir_all(fuzz_debug).unwrap_or(()); std::fs::create_dir_all(fuzz_debug).unwrap(); let out_dir_name = &format!("{fuzz_debug}/batch_size_{batch_size}"); + println!("The debug is ON. Input data will be saved to {out_dir_name}"); + Self::save_partitioned_batches_as_parquet( &self.input1, out_dir_name, @@ -562,8 +577,7 @@ impl JoinFuzzTestCase { /// Some(Box::new(col_lt_col_filter)), /// ) /// .run_test(&[JoinTestType::HjSmj], false) - /// .await - /// } + /// .await; /// /// let ctx: SessionContext = SessionContext::new(); /// let df = ctx @@ -592,6 +606,7 @@ impl JoinFuzzTestCase { /// ) /// .run_test() /// .await + /// } fn save_partitioned_batches_as_parquet( input: &[RecordBatch], output_dir: &str, diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index 1d151f9fd368..bc2c3315da59 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -164,7 +164,7 @@ async fn cross_join() { } #[tokio::test] -async fn merge_join() { +async fn sort_merge_join_no_spill() { // Planner chooses MergeJoin only if number of partitions > 1 let config = SessionConfig::new() .with_target_partitions(2) @@ -175,11 +175,32 @@ async fn merge_join() { "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", ) .with_expected_errors(vec![ - "Resources exhausted: Failed to allocate additional", + "Failed to allocate additional", "SMJStream", + "Disk spilling disabled", ]) .with_memory_limit(1_000) .with_config(config) + .with_scenario(Scenario::AccessLogStreaming) + .run() + .await +} + +#[tokio::test] +async fn sort_merge_join_spill() { + // Planner chooses MergeJoin only if number of partitions > 1 + let config = SessionConfig::new() + .with_target_partitions(2) + .set_bool("datafusion.optimizer.prefer_hash_join", false); + + TestCase::new() + .with_query( + "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", + ) + .with_memory_limit(1_000) + .with_config(config) + 
.with_disk_manager_config(DiskManagerConfig::NewOs) + .with_scenario(Scenario::AccessLogStreaming) .run() .await } @@ -259,7 +280,7 @@ async fn sort_spill_reservation() { .with_query("select * from t ORDER BY a , b DESC") // enough memory to sort if we don't try to merge it all at once .with_memory_limit(partition_size) - // use a single partiton so only a sort is needed + // use a single partition so only a sort is needed .with_scenario(scenario) .with_disk_manager_config(DiskManagerConfig::NewOs) .with_expected_plan( @@ -361,7 +382,7 @@ struct TestCase { /// How should the disk manager (that allows spilling) be /// configured? Defaults to `Disabled` disk_manager_config: DiskManagerConfig, - /// Expected explain plan, if non emptry + /// Expected explain plan, if non-empty expected_plan: Vec, /// Is the plan expected to pass? Defaults to false expected_success: bool, @@ -453,7 +474,7 @@ impl TestCase { let table = scenario.table(); let rt_config = RuntimeConfig::new() - // do not allow spilling + // disk manager setting controls the spilling .with_disk_manager(disk_manager_config) .with_memory_limit(memory_limit, MEMORY_FRACTION); diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs index 2b4ba0b17133..623f321ce152 100644 --- a/datafusion/core/tests/parquet/arrow_statistics.rs +++ b/datafusion/core/tests/parquet/arrow_statistics.rs @@ -579,7 +579,7 @@ async fn test_data_page_stats_with_all_null_page() { /////////////// MORE GENERAL TESTS ////////////////////// // . Many columns in a file -// . Differnet data types +// . Different data types // . Different row group sizes // Four different integer types @@ -1733,7 +1733,7 @@ async fn test_float16() { #[tokio::test] async fn test_decimal() { - // This creates a parquet file of 1 column "decimal_col" with decimal data type and precicion 9, scale 2 + // This creates a parquet file of 1 column "decimal_col" with decimal data type and precision 9, scale 2 // file has 3 record batches, each has 5 rows. They will be saved into 3 row groups let reader = TestReader { scenario: Scenario::Decimal, @@ -1763,7 +1763,7 @@ async fn test_decimal() { } #[tokio::test] async fn test_decimal_256() { - // This creates a parquet file of 1 column "decimal256_col" with decimal data type and precicion 9, scale 2 + // This creates a parquet file of 1 column "decimal256_col" with decimal data type and precision 9, scale 2 // file has 3 record batches, each has 5 rows. 
They will be saved into 3 row groups let reader = TestReader { scenario: Scenario::Decimal256, diff --git a/datafusion/core/tests/parquet/schema.rs b/datafusion/core/tests/parquet/schema.rs index 1b572914d7bd..e13fbad24426 100644 --- a/datafusion/core/tests/parquet/schema.rs +++ b/datafusion/core/tests/parquet/schema.rs @@ -25,7 +25,7 @@ use datafusion_common::assert_batches_sorted_eq; #[tokio::test] async fn schema_merge_ignores_metadata_by_default() { - // Create several parquet files in same directoty / table with + // Create several parquet files in same directory / table with // same schema but different metadata let tmp_dir = TempDir::new().unwrap(); let table_dir = tmp_dir.path().join("parquet_test"); @@ -103,7 +103,7 @@ async fn schema_merge_ignores_metadata_by_default() { #[tokio::test] async fn schema_merge_can_preserve_metadata() { - // Create several parquet files in same directoty / table with + // Create several parquet files in same directory / table with // same schema but different metadata let tmp_dir = TempDir::new().unwrap(); let table_dir = tmp_dir.path().join("parquet_test"); diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 1f4f9e77d5dc..1f10cb244e83 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -35,7 +35,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> { assert_eq!( *actual[0].schema(), Schema::new(vec![Field::new_list( - "ARRAY_AGG(DISTINCT aggregate_test_100.c2)", + "array_agg(DISTINCT aggregate_test_100.c2)", Field::new("item", DataType::UInt32, true), true ),]) diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 502590f9e2e2..fe4777b04396 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -352,7 +352,7 @@ async fn csv_explain_verbose() { // flatten to a single string let actual = actual.into_iter().map(|r| r.join("\t")).collect::(); - // Don't actually test the contents of the debuging output (as + // Don't actually test the contents of the debugging output (as // that may change and keeping this test updated will be a // pain). Instead just check for a few key pieces. 
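Relating to the memory-limit changes above (the `sort_merge_join_spill` test and the `TestCase` disk-manager comment): the pattern being exercised is a runtime with a hard memory cap plus an OS disk manager, so operators such as the sort-merge join can spill instead of failing. A minimal sketch of that configuration follows, assuming an illustrative 1 MB limit and a toy table; the table and column names are invented here, not taken from the test harness.

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::disk_manager::DiskManagerConfig;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::prelude::{SessionConfig, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    // Cap execution memory (illustrative 1 MB) and allow spilling to
    // OS-provided temporary directories.
    let runtime = RuntimeEnv::new(
        RuntimeConfig::new()
            .with_memory_limit(1024 * 1024, 1.0)
            .with_disk_manager(DiskManagerConfig::NewOs),
    )?;

    // Prefer sort-merge join so the spilling path is actually exercised.
    let config = SessionConfig::new()
        .with_target_partitions(2)
        .set_bool("datafusion.optimizer.prefer_hash_join", false);
    let ctx = SessionContext::new_with_config_rt(config, Arc::new(runtime));

    ctx.sql("CREATE TABLE t(pod VARCHAR, ts INT) AS VALUES ('a', 1), ('a', 2), ('b', 3)")
        .await?
        .collect()
        .await?;

    // With the disk manager enabled, a memory-hungry join may spill to disk
    // rather than returning a "resources exhausted" error.
    ctx.sql("SELECT t1.* FROM t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.ts = t2.ts")
        .await?
        .show()
        .await?;
    Ok(())
}
```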
assert_contains!(&actual, "logical_plan"); diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 5847952ae6a6..219f6c26cf8f 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -941,11 +941,11 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> { /// Saves whatever is passed to it as a scalar function #[derive(Debug, Default)] -struct RecordingFunctonFactory { +struct RecordingFunctionFactory { calls: Mutex>, } -impl RecordingFunctonFactory { +impl RecordingFunctionFactory { fn new() -> Self { Self::default() } @@ -957,7 +957,7 @@ impl RecordingFunctonFactory { } #[async_trait::async_trait] -impl FunctionFactory for RecordingFunctonFactory { +impl FunctionFactory for RecordingFunctionFactory { async fn create( &self, _state: &SessionState, @@ -972,7 +972,7 @@ impl FunctionFactory for RecordingFunctonFactory { #[tokio::test] async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<()> { - let function_factory = Arc::new(RecordingFunctonFactory::new()); + let function_factory = Arc::new(RecordingFunctionFactory::new()); let ctx = SessionContext::new().with_function_factory(function_factory.clone()); let sql = r#" diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index 3f66a304dc18..92ed1b2918de 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -18,7 +18,7 @@ //! [`MemoryPool`] for memory management during query execution, [`proxy]` for //! help with allocation accounting. -use datafusion_common::Result; +use datafusion_common::{internal_err, Result}; use std::{cmp::Ordering, sync::Arc}; mod pool; @@ -220,6 +220,23 @@ impl MemoryReservation { self.size = new_size } + /// Tries to free `capacity` bytes from this reservation + /// if `capacity` does not exceed [`Self::size`] + /// Returns new reservation size + /// or error if shrinking capacity is more than allocated size + pub fn try_shrink(&mut self, capacity: usize) -> Result { + if let Some(new_size) = self.size.checked_sub(capacity) { + self.registration.pool.shrink(self, capacity); + self.size = new_size; + Ok(new_size) + } else { + internal_err!( + "Cannot free the capacity {capacity} out of allocated size {}", + self.size + ) + } + } + /// Sets the size of this reservation to `capacity` pub fn resize(&mut self, capacity: usize) { match capacity.cmp(&self.size) { diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 3cae78eaed9b..4037e3c5db9b 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -17,13 +17,12 @@ //! 
Aggregate function module contains all built-in aggregate functions definitions -use std::sync::Arc; use std::{fmt, str::FromStr}; use crate::utils; use crate::{type_coercion::aggregates::*, Signature, Volatility}; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::DataType; use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; use strum_macros::EnumIter; @@ -37,8 +36,6 @@ pub enum AggregateFunction { Min, /// Maximum Max, - /// Aggregation into an array - ArrayAgg, } impl AggregateFunction { @@ -47,7 +44,6 @@ impl AggregateFunction { match self { Min => "MIN", Max => "MAX", - ArrayAgg => "ARRAY_AGG", } } } @@ -65,7 +61,6 @@ impl FromStr for AggregateFunction { // general "max" => AggregateFunction::Max, "min" => AggregateFunction::Min, - "array_agg" => AggregateFunction::ArrayAgg, _ => { return plan_err!("There is no built-in function named {name}"); } @@ -80,7 +75,7 @@ impl AggregateFunction { pub fn return_type( &self, input_expr_types: &[DataType], - input_expr_nullable: &[bool], + _input_expr_nullable: &[bool], ) -> Result { // Note that this function *must* return the same type that the respective physical expression returns // or the execution panics. @@ -105,11 +100,6 @@ impl AggregateFunction { // The coerced_data_types is same with input_types. Ok(coerced_data_types[0].clone()) } - AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new( - "item", - coerced_data_types[0].clone(), - input_expr_nullable[0], - )))), } } @@ -118,7 +108,6 @@ impl AggregateFunction { pub fn nullable(&self) -> Result { match self { AggregateFunction::Max | AggregateFunction::Min => Ok(true), - AggregateFunction::ArrayAgg => Ok(true), } } } @@ -128,7 +117,6 @@ impl AggregateFunction { pub fn signature(&self) -> Signature { // note: the physical expression must accept the type returned by this function or the execution panics. match self { - AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable), AggregateFunction::Min | AggregateFunction::Max => { let valid = STRINGS .iter() @@ -152,8 +140,8 @@ mod tests { use strum::IntoEnumIterator; #[test] - // Test for AggregateFuncion's Display and from_str() implementations. - // For each variant in AggregateFuncion, it converts the variant to a string + // Test for AggregateFunction's Display and from_str() implementations. + // For each variant in AggregateFunction, it converts the variant to a string // and then back to a variant. The test asserts that the original variant and // the reconstructed variant are the same. This assertion is also necessary for // function suggestion. See https://github.com/apache/datafusion/issues/8082 diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index a344e621ddb1..452c05be34f4 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -109,7 +109,7 @@ use sqlparser::ast::NullTreatment; /// ## Binary Expressions /// /// Exprs implement traits that allow easy to understand construction of more -/// complex expresions. For example, to create `c1 + c2` to add columns "c1" and +/// complex expressions. For example, to create `c1 + c2` to add columns "c1" and /// "c2" together /// /// ``` @@ -1340,6 +1340,9 @@ impl Expr { /// /// returns `None` if the expression is not a `Column` /// + /// Note: None may be returned for expressions that are not `Column` but + /// are convertible to `Column` such as `Cast` expressions. 
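Stepping back to the `MemoryReservation::try_shrink` method added in `memory_pool/mod.rs` above: unlike the infallible `shrink`, it refuses to release more bytes than the reservation currently holds and reports the remaining size. A small sketch of the intended behaviour, with an arbitrary pool size and consumer name:

```rust
use std::sync::Arc;

use datafusion_common::Result;
use datafusion_execution::memory_pool::{GreedyMemoryPool, MemoryConsumer, MemoryPool};

fn main() -> Result<()> {
    let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(1024));
    let mut reservation = MemoryConsumer::new("example").register(&pool);

    reservation.try_grow(100)?;
    assert_eq!(reservation.size(), 100);

    // Releasing less than the reserved amount succeeds and returns the new size.
    let remaining = reservation.try_shrink(40)?;
    assert_eq!(remaining, 60);

    // Releasing more than is reserved is rejected instead of silently underflowing.
    assert!(reservation.try_shrink(1_000).is_err());
    assert_eq!(reservation.size(), 60);
    Ok(())
}
```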
+ /// /// Example /// ``` /// # use datafusion_common::Column; @@ -1358,6 +1361,23 @@ impl Expr { } } + /// Returns the inner `Column` if any. This is a specialized version of + /// [`Self::try_as_col`] that take Cast expressions into account when the + /// expression is as on condition for joins. + /// + /// Called this method when you are sure that the expression is a `Column` + /// or a `Cast` expression that wraps a `Column`. + pub fn get_as_join_column(&self) -> Option<&Column> { + match self { + Expr::Column(c) => Some(c), + Expr::Cast(Cast { expr, .. }) => match &**expr { + Expr::Column(c) => Some(c), + _ => None, + }, + _ => None, + } + } + /// Return all referenced columns of this expression. #[deprecated(since = "40.0.0", note = "use Expr::column_refs instead")] pub fn to_columns(&self) -> Result> { @@ -1398,7 +1418,7 @@ impl Expr { } Ok(TreeNodeRecursion::Continue) }) - .expect("traversal is infallable"); + .expect("traversal is infallible"); } /// Return all references to columns and their occurrence counts in the expression. @@ -1433,7 +1453,7 @@ impl Expr { } Ok(TreeNodeRecursion::Continue) }) - .expect("traversal is infallable"); + .expect("traversal is infallible"); } /// Returns true if there are any column references in this Expr diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 8b0213fd52fd..9187e8352205 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -171,18 +171,6 @@ pub fn max(expr: Expr) -> Expr { )) } -/// Create an expression to represent the array_agg() aggregate function -pub fn array_agg(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::ArrayAgg, - vec![expr], - false, - None, - None, - None, - )) -} - /// Return a new expression with bitwise AND pub fn bitwise_and(left: Expr, right: Expr) -> Expr { Expr::BinaryExpr(BinaryExpr::new( diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 91bec501f4a0..8d460bdc8e7d 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -155,7 +155,7 @@ pub fn unnormalize_col(expr: Expr) -> Expr { }) }) .data() - .expect("Unnormalize is infallable") + .expect("Unnormalize is infallible") } /// Create a Column from the Scalar Expr @@ -201,7 +201,7 @@ pub fn strip_outer_reference(expr: Expr) -> Expr { }) }) .data() - .expect("strip_outer_reference is infallable") + .expect("strip_outer_reference is infallible") } /// Returns plan with expressions coerced to types compatible with diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 1df5d6c4d736..5e0571f712ee 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -112,7 +112,17 @@ impl ExprSchemable for Expr { Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), Expr::ScalarVariable(ty, _) => Ok(ty.clone()), Expr::Literal(l) => Ok(l.data_type()), - Expr::Case(case) => case.when_then_expr[0].1.get_type(schema), + Expr::Case(case) => { + for (_, then_expr) in &case.when_then_expr { + let then_type = then_expr.get_type(schema)?; + if !then_type.is_null() { + return Ok(then_type); + } + } + case.else_expr + .as_ref() + .map_or(Ok(DataType::Null), |e| e.get_type(schema)) + } Expr::Cast(Cast { data_type, .. }) | Expr::TryCast(TryCast { data_type, .. 
}) => Ok(data_type.clone()), Expr::Unnest(Unnest { expr }) => { diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 73ab51494de6..d722e55de487 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -20,7 +20,7 @@ use crate::ColumnarValue; use crate::{Accumulator, Expr, PartitionEvaluator}; use arrow::datatypes::{DataType, Field, Schema}; -use datafusion_common::Result; +use datafusion_common::{DFSchema, Result}; use std::sync::Arc; #[derive(Debug, Clone, Copy)] @@ -57,6 +57,9 @@ pub struct AccumulatorArgs<'a> { /// The schema of the input arguments pub schema: &'a Schema, + /// The schema of the input arguments + pub dfschema: &'a DFSchema, + /// Whether to ignore nulls. /// /// SQL allows the user to specify `IGNORE NULLS`, for example: @@ -78,6 +81,9 @@ pub struct AccumulatorArgs<'a> { /// If no `ORDER BY` is specified, `sort_exprs`` will be empty. pub sort_exprs: &'a [Expr], + /// Whether the aggregation is running in reverse order + pub is_reversed: bool, + /// The name of the aggregate expression pub name: &'a str, diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 4ad3bd5018a4..98e262f0b187 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -412,14 +412,14 @@ impl LogicalPlanBuilder { /// Add missing sort columns to all downstream projection /// - /// Thus, if you have a LogialPlan that selects A and B and have + /// Thus, if you have a LogicalPlan that selects A and B and have /// not requested a sort by C, this code will add C recursively to /// all input projections. /// /// Adding a new column is not correct if there is a `Distinct` /// node, which produces only distinct values of its /// inputs. Adding a new column to its input will result in - /// potententially different results than with the original column. + /// potentially different results than with the original column. /// /// For example, if the input is like: /// @@ -1763,7 +1763,7 @@ mod tests { .unwrap(); assert_eq!(&expected, plan.schema().as_ref()); - // Note scan of "EMPLOYEE_CSV" is treated as a SQL identifer + // Note scan of "EMPLOYEE_CSV" is treated as a SQL identifier // (and thus normalized to "employee"csv") as well let projection = None; let plan = diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 81fd03555abb..343eda056ffe 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -338,9 +338,9 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { .collect::>() .join(", "); - let elipse = if values.len() > 5 { "..." } else { "" }; + let eclipse = if values.len() > 5 { "..." } else { "" }; - let values_str = format!("{}{}", str_values, elipse); + let values_str = format!("{}{}", str_values, eclipse); json!({ "Node Type": "Values", "Values": values_str diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index bde9655b8a39..d4fe233cac06 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -263,7 +263,7 @@ pub enum LogicalPlan { /// Prepare a statement and find any bind parameters /// (e.g. `?`). This is used to implement SQL-prepared statements. 
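Returning to the `Expr::Case` change in `expr_schema.rs` earlier in this section: `get_type` now skips `NULL`-typed THEN branches instead of blindly taking the type of the first one. A small sketch of the observable effect, with a made-up schema and column name:

```rust
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{DFSchema, Result, ScalarValue};
use datafusion_expr::{col, lit, when, Expr, ExprSchemable};

fn main() -> Result<()> {
    let schema: DFSchema =
        Schema::new(vec![Field::new("flag", DataType::Boolean, true)]).try_into()?;

    // CASE WHEN flag THEN NULL ELSE 1 END
    let expr = when(col("flag"), Expr::Literal(ScalarValue::Null)).otherwise(lit(1i64))?;

    // The first THEN branch is NULL-typed, so the resolved type now comes from
    // the ELSE branch rather than collapsing to DataType::Null.
    assert_eq!(expr.get_type(&schema)?, DataType::Int64);
    Ok(())
}
```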
Prepare(Prepare), - /// Data Manipulaton Language (DML): Insert / Update / Delete + /// Data Manipulation Language (DML): Insert / Update / Delete Dml(DmlStatement), /// Data Definition Language (DDL): CREATE / DROP TABLES / VIEWS / SCHEMAS Ddl(DdlStatement), @@ -496,18 +496,18 @@ impl LogicalPlan { // The join keys in using-join must be columns. let columns = on.iter().try_fold(HashSet::new(), |mut accumu, (l, r)| { - let Some(l) = l.try_as_col().cloned() else { + let Some(l) = l.get_as_join_column() else { return internal_err!( "Invalid join key. Expected column, found {l:?}" ); }; - let Some(r) = r.try_as_col().cloned() else { + let Some(r) = r.get_as_join_column() else { return internal_err!( "Invalid join key. Expected column, found {r:?}" ); }; - accumu.insert(l); - accumu.insert(r); + accumu.insert(l.to_owned()); + accumu.insert(r.to_owned()); Result::<_, DataFusionError>::Ok(accumu) })?; using_columns.push(columns); @@ -1598,8 +1598,8 @@ impl LogicalPlan { }) .collect(); - let elipse = if values.len() > 5 { "..." } else { "" }; - write!(f, "Values: {}{}", str_values.join(", "), elipse) + let eclipse = if values.len() > 5 { "..." } else { "" }; + write!(f, "Values: {}{}", str_values.join(", "), eclipse) } LogicalPlan::TableScan(TableScan { diff --git a/datafusion/expr/src/partition_evaluator.rs b/datafusion/expr/src/partition_evaluator.rs index 04b6faf55ae1..a0f0988b4f4e 100644 --- a/datafusion/expr/src/partition_evaluator.rs +++ b/datafusion/expr/src/partition_evaluator.rs @@ -135,7 +135,7 @@ pub trait PartitionEvaluator: Debug + Send { /// must produce an output column with one output row for every /// input row. /// - /// `num_rows` is requied to correctly compute the output in case + /// `num_rows` is required to correctly compute the output in case /// `values.len() == 0` /// /// Implementing this function is an optimization: certain window diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 009f3512c588..c775427df138 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -19,7 +19,7 @@ use std::sync::Arc; -use arrow::datatypes::{DataType, SchemaRef}; +use arrow::datatypes::{DataType, Field, SchemaRef}; use datafusion_common::{ config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema, Result, TableReference, @@ -173,6 +173,30 @@ pub trait ExprPlanner: Send + Sync { fn plan_overlay(&self, args: Vec) -> Result>> { Ok(PlannerResult::Original(args)) } + + /// Plan a make_map expression, e.g., `make_map(key1, value1, key2, value2, ...)` + /// + /// Returns origin expression arguments if not possible + fn plan_make_map(&self, args: Vec) -> Result>> { + Ok(PlannerResult::Original(args)) + } + + /// Plans compound identifier eg `db.schema.table` for non-empty nested names + /// + /// Note: + /// Currently compound identifier for outer query schema is not supported. 
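To illustrate the `Expr::get_as_join_column` helper introduced above and used by `LogicalPlan::using_columns`: unlike `try_as_col`, it also looks through a `Cast` wrapped around a column, which is what equi-join keys can look like after type coercion. A brief sketch (the column names are invented):

```rust
use arrow::datatypes::DataType;
use datafusion_expr::{cast, col, Expr};

fn main() {
    let plain = col("t1.id");
    let casted = cast(col("t1.id"), DataType::Int64);

    // Both forms expose the underlying join column...
    assert_eq!(plain.get_as_join_column().unwrap().name, "id");
    assert_eq!(casted.get_as_join_column().unwrap().name, "id");

    // ...while an arbitrary expression does not.
    let arithmetic: Expr = col("t1.id") + col("t1.other");
    assert!(arithmetic.get_as_join_column().is_none());
}
```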
+ /// + /// Returns planned expression + fn plan_compound_identifier( + &self, + _field: &Field, + _qualifier: Option<&TableReference>, + _nested_names: &[String], + ) -> Result>> { + not_impl_err!( + "Default planner compound identifier hasn't been implemented for ExprPlanner" + ) + } } /// An operator with two arguments to plan diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index fba793dd229d..eadd7ac2f83f 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -65,7 +65,7 @@ pub enum Volatility { /// automatically coerces (add casts to) function arguments so they match the type signature. /// /// For example, a function like `cos` may only be implemented for `Float64` arguments. To support a query -/// that calles `cos` with a different argument type, such as `cos(int_column)`, type coercion automatically +/// that calls `cos` with a different argument type, such as `cos(int_column)`, type coercion automatically /// adds a cast such as `cos(CAST int_column AS DOUBLE)` during planning. /// /// # Data Types diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index fbec6e2f8024..a024401e18d5 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -95,7 +95,6 @@ pub fn coerce_types( check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?; match agg_fun { - AggregateFunction::ArrayAgg => Ok(input_types.to_vec()), AggregateFunction::Min | AggregateFunction::Max => { // min and max support the dictionary data type // unpack the dictionary to get the value @@ -360,11 +359,7 @@ mod tests { // test count, array_agg, approx_distinct, min, max. // the coerced types is same with input types - let funs = vec![ - AggregateFunction::ArrayAgg, - AggregateFunction::Min, - AggregateFunction::Max, - ]; + let funs = vec![AggregateFunction::Min, AggregateFunction::Max]; let input_types = vec![ vec![DataType::Int32], vec![DataType::Decimal128(10, 2)], diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 70139aaa4a0c..e1765b5c3e6a 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -370,7 +370,7 @@ impl From<&DataType> for TypeCategory { /// The rules in the document provide a clue, but adhering strictly to them doesn't precisely /// align with the behavior of Postgres. Therefore, we've made slight adjustments to the rules /// to better match the behavior of both Postgres and DuckDB. For example, we expect adjusted -/// decimal percision and scale when coercing decimal types. +/// decimal precision and scale when coercing decimal types. pub fn type_union_resolution(data_types: &[DataType]) -> Option { if data_types.is_empty() { return None; @@ -718,7 +718,7 @@ pub fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { (Int16 | Int32 | Int64, Int8) | (Int32 | Int64, Int16) | (Int64, Int32) | // Left Float is larger than right Float. (Float32 | Float64, Float16) | (Float64, Float32) | - // Left String is larget than right String. + // Left String is larger than right String. (LargeUtf8, Utf8) | // Any left type is wider than a right hand side Null. 
(_, Null) => lhs.clone(), diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index b430b343e484..ef52a01e0598 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -646,7 +646,7 @@ mod tests { vec![DataType::UInt8, DataType::UInt16], Some(vec![DataType::UInt8, DataType::UInt16]), ), - // 2 entries, can coerse values + // 2 entries, can coerce values ( vec![DataType::UInt16, DataType::UInt16], vec![DataType::UInt8, DataType::UInt16], diff --git a/datafusion/expr/src/type_coercion/mod.rs b/datafusion/expr/src/type_coercion/mod.rs index 86005da3dafa..e0d1236aac2d 100644 --- a/datafusion/expr/src/type_coercion/mod.rs +++ b/datafusion/expr/src/type_coercion/mod.rs @@ -19,7 +19,7 @@ //! //! Coercion is performed automatically by DataFusion when the types //! of arguments passed to a function or needed by operators do not -//! exacty match the types required by that function / operator. In +//! exactly match the types required by that function / operator. In //! this case, DataFusion will attempt to *coerce* the arguments to //! types accepted by the function by inserting CAST operations. //! diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 1657e034fbe2..2851ca811e0c 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -542,7 +542,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { pub enum ReversedUDAF { /// The expression is the same as the original expression, like SUM, COUNT Identical, - /// The expression does not support reverse calculation, like ArrayAgg + /// The expression does not support reverse calculation NotSupported, /// The expression is different from the original expression Reversed(Arc), diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 45155cbd2c27..889aa0952e51 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1212,7 +1212,7 @@ pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { } } -/// Build state name. State is the intermidiate state of the aggregate function. +/// Build state name. State is the intermediate state of the aggregate function. 
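As a concrete reading of the wider-type rules in `type_coercion/binary.rs` touched above, a small sketch using the public `get_wider_type` helper; only a few of the supported pairs are shown:

```rust
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::type_coercion::binary::get_wider_type;

fn main() -> Result<()> {
    // The wider of two integer types wins.
    assert_eq!(get_wider_type(&DataType::Int64, &DataType::Int32)?, DataType::Int64);

    // LargeUtf8 is considered wider than Utf8.
    assert_eq!(
        get_wider_type(&DataType::LargeUtf8, &DataType::Utf8)?,
        DataType::LargeUtf8
    );

    // A Null right-hand side yields the left-hand type.
    assert_eq!(get_wider_type(&DataType::Float64, &DataType::Null)?, DataType::Float64);
    Ok(())
}
```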
pub fn format_state_name(name: &str, state_name: &str) -> String { format!("{name}[{state_name}]") } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index bbe7d21e2486..dfb94a84cbec 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -30,7 +30,8 @@ use arrow::{ use arrow_schema::{Field, Schema}; use datafusion_common::{ - downcast_value, internal_err, not_impl_err, plan_err, DataFusionError, ScalarValue, + downcast_value, internal_err, not_impl_err, plan_err, DFSchema, DataFusionError, + ScalarValue, }; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; @@ -42,7 +43,7 @@ use datafusion_expr::{ use datafusion_physical_expr_common::aggregate::tdigest::{ TDigest, TryIntoF64, DEFAULT_MAX_SIZE, }; -use datafusion_physical_expr_common::utils::limited_convert_logical_expr_to_physical_expr; +use datafusion_physical_expr_common::utils::limited_convert_logical_expr_to_physical_expr_with_dfschema; make_udaf_expr_and_func!( ApproxPercentileCont, @@ -135,7 +136,9 @@ impl ApproxPercentileCont { fn get_lit_value(expr: &Expr) -> datafusion_common::Result { let empty_schema = Arc::new(Schema::empty()); let empty_batch = RecordBatch::new_empty(Arc::clone(&empty_schema)); - let expr = limited_convert_logical_expr_to_physical_expr(expr, &empty_schema)?; + let dfschema = DFSchema::empty(); + let expr = + limited_convert_logical_expr_to_physical_expr_with_dfschema(expr, &dfschema)?; let result = expr.evaluate(&empty_batch)?; match result { ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!( diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index a64218c606c4..0dbea1fb1ff7 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -43,7 +43,7 @@ make_udaf_expr_and_func!( approx_percentile_cont_with_weight_udaf ); -/// APPROX_PERCENTILE_CONT_WITH_WEIGTH aggregate expression +/// APPROX_PERCENTILE_CONT_WITH_WEIGHT aggregate expression pub struct ApproxPercentileContWithWeight { signature: Signature, approx_percentile_cont: ApproxPercentileCont, diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/functions-aggregate/src/array_agg.rs similarity index 66% rename from datafusion/physical-expr/src/aggregate/array_agg_ordered.rs rename to datafusion/functions-aggregate/src/array_agg.rs index 992c06f5bf62..96b39ae4121e 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -15,151 +15,294 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expressions which specify ordering requirement -//! that can evaluated at runtime during query execution +//! 
`ARRAY_AGG` aggregate implementation: [`ArrayAgg`] -use std::any::Any; -use std::collections::VecDeque; -use std::fmt::Debug; -use std::sync::Arc; - -use crate::aggregate::utils::{down_cast_any_ref, ordering_fields}; -use crate::expressions::format_state_name; -use crate::{ - reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr, -}; +use arrow::array::{new_empty_array, Array, ArrayRef, AsArray, StructArray}; +use arrow::datatypes::DataType; -use arrow::datatypes::{DataType, Field}; -use arrow_array::cast::AsArray; -use arrow_array::{new_empty_array, Array, ArrayRef, StructArray}; -use arrow_schema::Fields; +use arrow_schema::{Field, Fields}; +use datafusion_common::cast::as_list_array; use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx}; -use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::utils::AggregateOrderSensitivity; -use datafusion_expr::Accumulator; +use datafusion_common::{exec_err, ScalarValue}; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::AggregateUDFImpl; +use datafusion_expr::{Accumulator, Signature, Volatility}; use datafusion_physical_expr_common::aggregate::merge_arrays::merge_ordered_arrays; +use datafusion_physical_expr_common::aggregate::utils::ordering_fields; +use datafusion_physical_expr_common::sort_expr::{ + limited_convert_logical_sort_exprs_to_physical_with_dfschema, LexOrdering, + PhysicalSortExpr, +}; +use std::collections::{HashSet, VecDeque}; +use std::sync::Arc; + +make_udaf_expr_and_func!( + ArrayAgg, + array_agg, + expression, + "input values, including nulls, concatenated into an array", + array_agg_udaf +); -/// Expression for a `ARRAY_AGG(... ORDER BY ..., ...)` aggregation. In a multi -/// partition setting, partial aggregations are computed for every partition, -/// and then their results are merged. 
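Since `ARRAY_AGG` is being rebuilt here as a user-defined aggregate in `datafusion-functions-aggregate`, callers now reach it through `expr_fn::array_agg` (as the dataframe-test imports near the top of this section already show) or through SQL. A hedged usage sketch with a small assumed in-memory table:

```rust
use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_functions_aggregate::expr_fn::array_agg;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    ctx.sql("CREATE TABLE t(g INT, v INT) AS VALUES (1, 10), (1, 20), (2, 30)")
        .await?
        .collect()
        .await?;

    // DataFrame API: the UDAF-backed array_agg expression.
    ctx.table("t")
        .await?
        .aggregate(vec![col("g")], vec![array_agg(col("v"))])?
        .show()
        .await?;

    // SQL spelling, including the DISTINCT form handled by the new accumulators.
    ctx.sql("SELECT g, array_agg(DISTINCT v) FROM t GROUP BY g")
        .await?
        .show()
        .await?;
    Ok(())
}
```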
#[derive(Debug)] -pub struct OrderSensitiveArrayAgg { - /// Column name - name: String, - /// The `DataType` for the input expression - input_data_type: DataType, - /// The input expression - expr: Arc, - /// Ordering data types - order_by_data_types: Vec, - /// Ordering requirement - ordering_req: LexOrdering, - /// Whether the aggregation is running in reverse - reverse: bool, +/// ARRAY_AGG aggregate expression +pub struct ArrayAgg { + signature: Signature, } -impl OrderSensitiveArrayAgg { - /// Create a new `OrderSensitiveArrayAgg` aggregate function - pub fn new( - expr: Arc, - name: impl Into, - input_data_type: DataType, - order_by_data_types: Vec, - ordering_req: LexOrdering, - ) -> Self { +impl Default for ArrayAgg { + fn default() -> Self { Self { - name: name.into(), - input_data_type, - expr, - order_by_data_types, - ordering_req, - reverse: false, + signature: Signature::any(1, Volatility::Immutable), } } } -impl AggregateExpr for OrderSensitiveArrayAgg { - fn as_any(&self) -> &dyn Any { +impl AggregateUDFImpl for ArrayAgg { + fn as_any(&self) -> &dyn std::any::Any { self } - fn field(&self) -> Result { - Ok(Field::new_list( - &self.name, - // This should be the same as return type of AggregateFunction::OrderSensitiveArrayAgg - Field::new("item", self.input_data_type.clone(), true), - true, - )) + fn name(&self) -> &str { + "array_agg" } - fn create_accumulator(&self) -> Result> { - OrderSensitiveArrayAggAccumulator::try_new( - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - self.reverse, - ) - .map(|acc| Box::new(acc) as _) + fn aliases(&self) -> &[String] { + &[] + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(DataType::List(Arc::new(Field::new( + "item", + arg_types[0].clone(), + true, + )))) } - fn state_fields(&self) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + if args.is_distinct { + return Ok(vec![Field::new_list( + format_state_name(args.name, "distinct_array_agg"), + Field::new("item", args.input_type.clone(), true), + true, + )]); + } + let mut fields = vec![Field::new_list( - format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), true), - true, // This should be the same as field() + format_state_name(args.name, "array_agg"), + Field::new("item", args.input_type.clone(), true), + true, )]; - let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types); + + if args.ordering_fields.is_empty() { + return Ok(fields); + } + + let orderings = args.ordering_fields.to_vec(); fields.push(Field::new_list( - format_state_name(&self.name, "array_agg_orderings"), + format_state_name(args.name, "array_agg_orderings"), Field::new("item", DataType::Struct(Fields::from(orderings)), true), false, )); + Ok(fields) } - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if acc_args.is_distinct { + return Ok(Box::new(DistinctArrayAggAccumulator::try_new( + acc_args.input_type, + )?)); + } + + if acc_args.sort_exprs.is_empty() { + return Ok(Box::new(ArrayAggAccumulator::try_new(acc_args.input_type)?)); + } + + let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema( + acc_args.sort_exprs, + acc_args.dfschema, + )?; + + let ordering_dtypes = ordering_req + .iter() + .map(|e| e.expr.data_type(acc_args.schema)) + .collect::>>()?; + + OrderSensitiveArrayAggAccumulator::try_new( + 
acc_args.input_type, + &ordering_dtypes, + ordering_req, + acc_args.is_reversed, + ) + .map(|acc| Box::new(acc) as _) + } + + fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { + datafusion_expr::ReversedUDAF::Reversed(array_agg_udaf()) } +} + +#[derive(Debug)] +pub struct ArrayAggAccumulator { + values: Vec, + datatype: DataType, +} - fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - (!self.ordering_req.is_empty()).then_some(&self.ordering_req) +impl ArrayAggAccumulator { + /// new array_agg accumulator based on given item data type + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + values: vec![], + datatype: datatype.clone(), + }) } +} + +impl Accumulator for ArrayAggAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // Append value like Int64Array(1,2,3) + if values.is_empty() { + return Ok(()); + } + + if values.len() != 1 { + return internal_err!("expects single batch"); + } - fn order_sensitivity(&self) -> AggregateOrderSensitivity { - AggregateOrderSensitivity::HardRequirement + let val = Arc::clone(&values[0]); + if val.len() > 0 { + self.values.push(val); + } + Ok(()) } - fn name(&self) -> &str { - &self.name + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6)) + if states.is_empty() { + return Ok(()); + } + + if states.len() != 1 { + return internal_err!("expects single state"); + } + + let list_arr = as_list_array(&states[0])?; + for arr in list_arr.iter().flatten() { + self.values.push(arr); + } + Ok(()) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn evaluate(&mut self) -> Result { + // Transform Vec to ListArr + let element_arrays: Vec<&dyn Array> = + self.values.iter().map(|a| a.as_ref()).collect(); + + if element_arrays.is_empty() { + return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); + } + + let concated_array = arrow::compute::concat(&element_arrays)?; + let list_array = array_into_list_array_nullable(concated_array); + + Ok(ScalarValue::List(Arc::new(list_array))) } - fn reverse_expr(&self) -> Option> { - Some(Arc::new(Self { - name: self.name.to_string(), - input_data_type: self.input_data_type.clone(), - expr: Arc::clone(&self.expr), - order_by_data_types: self.order_by_data_types.clone(), - // Reverse requirement: - ordering_req: reverse_order_bys(&self.ordering_req), - reverse: !self.reverse, - })) + fn size(&self) -> usize { + std::mem::size_of_val(self) + + (std::mem::size_of::() * self.values.capacity()) + + self + .values + .iter() + .map(|arr| arr.get_array_memory_size()) + .sum::() + + self.datatype.size() + - std::mem::size_of_val(&self.datatype) } } -impl PartialEq for OrderSensitiveArrayAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.order_by_data_types == x.order_by_data_types - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) +#[derive(Debug)] +struct DistinctArrayAggAccumulator { + values: HashSet, + datatype: DataType, +} + +impl DistinctArrayAggAccumulator { + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + values: HashSet::new(), + datatype: datatype.clone(), + }) + } +} + +impl Accumulator for DistinctArrayAggAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.len() != 1 { + return 
internal_err!("expects single batch"); + } + + let array = &values[0]; + + for i in 0..array.len() { + let scalar = ScalarValue::try_from_array(&array, i)?; + self.values.insert(scalar); + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + + if states.len() != 1 { + return internal_err!("expects single state"); + } + + states[0] + .as_list::() + .iter() + .flatten() + .try_for_each(|val| self.update_batch(&[val])) + } + + fn evaluate(&mut self) -> Result { + let values: Vec = self.values.iter().cloned().collect(); + if values.is_empty() { + return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); + } + let arr = ScalarValue::new_list(&values, &self.datatype, true); + Ok(ScalarValue::List(arr)) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + ScalarValue::size_of_hashset(&self.values) + - std::mem::size_of_val(&self.values) + + self.datatype.size() + - std::mem::size_of_val(&self.datatype) } } +/// Accumulator for a `ARRAY_AGG(... ORDER BY ..., ...)` aggregation. In a multi +/// partition setting, partial aggregations are computed for every partition, +/// and then their results are merged. #[derive(Debug)] pub(crate) struct OrderSensitiveArrayAggAccumulator { /// Stores entries in the `ARRAY_AGG` result. @@ -291,6 +434,7 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { fn state(&mut self) -> Result> { let mut result = vec![self.evaluate()?]; result.push(self.evaluate_orderings()?); + Ok(result) } @@ -375,13 +519,14 @@ impl OrderSensitiveArrayAggAccumulator { #[cfg(test)] mod tests { + use super::*; + use std::collections::VecDeque; use std::sync::Arc; - use crate::aggregate::array_agg_ordered::merge_ordered_arrays; - - use arrow_array::{Array, ArrayRef, Int64Array}; + use arrow::array::Int64Array; use arrow_schema::SortOptions; + use datafusion_common::utils::get_row_at_idx; use datafusion_common::{Result, ScalarValue}; diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index 9224b06e407a..6c2d6cb5285c 100644 --- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -358,6 +358,15 @@ where Ok(()) } + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // XOR is it's own inverse + self.update_batch(values) + } + + fn supports_retract_batch(&self) -> bool { + true + } + fn evaluate(&mut self) -> Result { ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) } @@ -456,3 +465,41 @@ where Ok(()) } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::array::{ArrayRef, UInt64Array}; + use arrow::datatypes::UInt64Type; + use datafusion_common::ScalarValue; + + use crate::bit_and_or_xor::BitXorAccumulator; + use datafusion_expr::Accumulator; + + #[test] + fn test_bit_xor_accumulator() { + let mut accumulator = BitXorAccumulator:: { value: None }; + let batches: Vec<_> = vec![vec![1, 2], vec![1]] + .into_iter() + .map(|b| Arc::new(b.into_iter().collect::()) as ArrayRef) + .collect(); + + let added = &[Arc::clone(&batches[0])]; + let retracted = &[Arc::clone(&batches[1])]; + + // XOR of 1..3 is 3 + accumulator.update_batch(added).unwrap(); + assert_eq!( + accumulator.evaluate().unwrap(), + ScalarValue::UInt64(Some(3)) + ); + + // Removing [1] ^ 3 = 2 + accumulator.retract_batch(retracted).unwrap(); + assert_eq!( + accumulator.evaluate().unwrap(), + ScalarValue::UInt64(Some(2)) + ); + } +} diff --git 
a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 0a667d35dce5..0ead22e90a16 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -44,7 +44,7 @@ use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, Volatility, }; -use datafusion_expr::{Expr, ReversedUDAF}; +use datafusion_expr::{Expr, ReversedUDAF, TypeSignature}; use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::accumulate_indices; use datafusion_physical_expr_common::{ aggregate::count_distinct::{ @@ -95,7 +95,11 @@ impl Default for Count { impl Count { pub fn new() -> Self { Self { - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::one_of( + // TypeSignature::Any(0) is required to handle `Count()` with no args + vec![TypeSignature::VariadicAny, TypeSignature::Any(0)], + Volatility::Immutable, + ), } } } diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 0e619bacef82..ba11f7e91e07 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -36,7 +36,8 @@ use datafusion_expr::{ }; use datafusion_physical_expr_common::aggregate::utils::get_sort_options; use datafusion_physical_expr_common::sort_expr::{ - limited_convert_logical_sort_exprs_to_physical, LexOrdering, PhysicalSortExpr, + limited_convert_logical_sort_exprs_to_physical_with_dfschema, LexOrdering, + PhysicalSortExpr, }; create_func!(FirstValue, first_value_udaf); @@ -116,9 +117,9 @@ impl AggregateUDFImpl for FirstValue { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let ordering_req = limited_convert_logical_sort_exprs_to_physical( + let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema( acc_args.sort_exprs, - acc_args.schema, + acc_args.dfschema, )?; let ordering_dtypes = ordering_req @@ -415,9 +416,9 @@ impl AggregateUDFImpl for LastValue { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let ordering_req = limited_convert_logical_sort_exprs_to_physical( + let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema( acc_args.sort_exprs, - acc_args.schema, + acc_args.dfschema, )?; let ordering_dtypes = ordering_req diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index a3808a08b007..b39b1955bb07 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -58,6 +58,7 @@ pub mod macros; pub mod approx_distinct; +pub mod array_agg; pub mod correlation; pub mod count; pub mod covariance; @@ -93,6 +94,7 @@ pub mod expr_fn { pub use super::approx_median::approx_median; pub use super::approx_percentile_cont::approx_percentile_cont; pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight; + pub use super::array_agg::array_agg; pub use super::average::avg; pub use super::bit_and_or_xor::bit_and; pub use super::bit_and_or_xor::bit_or; @@ -128,6 +130,7 @@ pub mod expr_fn { /// Returns all default aggregate functions pub fn all_default_aggregate_functions() -> Vec> { vec![ + array_agg::array_agg_udaf(), first_last::first_value_udaf(), first_last::last_value_udaf(), covariance::covar_samp_udaf(), @@ -191,8 +194,9 @@ mod tests { let mut names = HashSet::new(); for func in all_default_aggregate_functions() 
{ // TODO: remove this - // These functions are in intermediate migration state, skip them - if func.name().to_lowercase() == "count" { + // These functions are in intermidiate migration state, skip them + let name_lower_case = func.name().to_lowercase(); + if name_lower_case == "count" || name_lower_case == "array_agg" { continue; } assert!( diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 6719c673c55b..9bbd68c9bdf6 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -36,7 +36,8 @@ use datafusion_expr::{ use datafusion_physical_expr_common::aggregate::merge_arrays::merge_ordered_arrays; use datafusion_physical_expr_common::aggregate::utils::ordering_fields; use datafusion_physical_expr_common::sort_expr::{ - limited_convert_logical_sort_exprs_to_physical, LexOrdering, PhysicalSortExpr, + limited_convert_logical_sort_exprs_to_physical_with_dfschema, LexOrdering, + PhysicalSortExpr, }; make_udaf_expr_and_func!( @@ -111,9 +112,9 @@ impl AggregateUDFImpl for NthValueAgg { ), }?; - let ordering_req = limited_convert_logical_sort_exprs_to_physical( + let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema( acc_args.sort_exprs, - acc_args.schema, + acc_args.dfschema, )?; let ordering_dtypes = ordering_req diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 42cf44f65d8f..247962dc2ce1 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -273,6 +273,7 @@ mod tests { use arrow::{array::*, datatypes::*}; + use datafusion_common::DFSchema; use datafusion_expr::AggregateUDF; use datafusion_physical_expr_common::aggregate::utils::get_accum_scalar_values_as_arrays; use datafusion_physical_expr_common::expressions::column::col; @@ -324,13 +325,16 @@ mod tests { agg2: Arc, schema: &Schema, ) -> Result { + let dfschema = DFSchema::empty(); let args1 = AccumulatorArgs { data_type: &DataType::Float64, schema, + dfschema: &dfschema, ignore_nulls: false, sort_exprs: &[], name: "a", is_distinct: false, + is_reversed: false, input_type: &DataType::Float64, input_exprs: &[datafusion_expr::col("a")], }; @@ -338,10 +342,12 @@ mod tests { let args2 = AccumulatorArgs { data_type: &DataType::Float64, schema, + dfschema: &dfschema, ignore_nulls: false, sort_exprs: &[], name: "a", is_distinct: false, + is_reversed: false, input_type: &DataType::Float64, input_exprs: &[datafusion_expr::col("a")], }; diff --git a/datafusion/functions-array/Cargo.toml b/datafusion/functions-array/Cargo.toml index 73c5b9114a2c..de424b259694 100644 --- a/datafusion/functions-array/Cargo.toml +++ b/datafusion/functions-array/Cargo.toml @@ -53,6 +53,7 @@ datafusion-functions-aggregate = { workspace = true } itertools = { version = "0.12", features = ["use_std"] } log = { workspace = true } paste = "1.0.14" +rand = "0.8.5" [dev-dependencies] criterion = { version = "0.5", features = ["async_tokio"] } @@ -60,3 +61,7 @@ criterion = { version = "0.5", features = ["async_tokio"] } [[bench]] harness = false name = "array_expression" + +[[bench]] +harness = false +name = "map" diff --git a/datafusion/functions/benches/map.rs b/datafusion/functions-array/benches/map.rs similarity index 83% rename from datafusion/functions/benches/map.rs rename to datafusion/functions-array/benches/map.rs index cd863d0e3311..c2e0e641e80d 100644 --- a/datafusion/functions/benches/map.rs +++ 
b/datafusion/functions-array/benches/map.rs @@ -17,17 +17,20 @@ extern crate criterion; -use arrow::array::{Int32Array, ListArray, StringArray}; -use arrow::datatypes::{DataType, Field}; +use arrow_array::{Int32Array, ListArray, StringArray}; use arrow_buffer::{OffsetBuffer, ScalarBuffer}; +use arrow_schema::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_common::ScalarValue; -use datafusion_expr::ColumnarValue; -use datafusion_functions::core::{make_map, map}; use rand::prelude::ThreadRng; use rand::Rng; use std::sync::Arc; +use datafusion_common::ScalarValue; +use datafusion_expr::planner::ExprPlanner; +use datafusion_expr::{ColumnarValue, Expr}; +use datafusion_functions_array::map::map_udf; +use datafusion_functions_array::planner::ArrayFunctionPlanner; + fn keys(rng: &mut ThreadRng) -> Vec { let mut keys = vec![]; for _ in 0..1000 { @@ -51,16 +54,16 @@ fn criterion_benchmark(c: &mut Criterion) { let values = values(&mut rng); let mut buffer = Vec::new(); for i in 0..1000 { - buffer.push(ColumnarValue::Scalar(ScalarValue::Utf8(Some( - keys[i].clone(), - )))); - buffer.push(ColumnarValue::Scalar(ScalarValue::Int32(Some(values[i])))); + buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone())))); + buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])))); } + let planner = ArrayFunctionPlanner {}; + b.iter(|| { black_box( - make_map() - .invoke(&buffer) + planner + .plan_make_map(buffer.clone()) .expect("map should work on valid values"), ); }); @@ -89,7 +92,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box( - map() + map_udf() .invoke(&[keys.clone(), values.clone()]) .expect("map should work on valid values"), ); diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs index 9717d29883fd..f68f59dcd6a1 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-array/src/lib.rs @@ -41,6 +41,7 @@ pub mod extract; pub mod flatten; pub mod length; pub mod make_array; +pub mod map; pub mod planner; pub mod position; pub mod range; @@ -53,6 +54,7 @@ pub mod set_ops; pub mod sort; pub mod string; pub mod utils; + use datafusion_common::Result; use datafusion_execution::FunctionRegistry; use datafusion_expr::ScalarUDF; @@ -140,6 +142,7 @@ pub fn all_default_array_functions() -> Vec> { replace::array_replace_n_udf(), replace::array_replace_all_udf(), replace::array_replace_udf(), + map::map_udf(), ] } diff --git a/datafusion/functions/src/core/map.rs b/datafusion/functions-array/src/map.rs similarity index 52% rename from datafusion/functions/src/core/map.rs rename to datafusion/functions-array/src/map.rs index 1834c7ac6060..e218b501dcf1 100644 --- a/datafusion/functions/src/core/map.rs +++ b/datafusion/functions-array/src/map.rs @@ -15,18 +15,26 @@ // specific language governing permissions and limitations // under the License. 
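The benchmark above now routes `make_map` through `ArrayFunctionPlanner::plan_make_map`, which receives its arguments as a flat, interleaved `[key1, value1, key2, value2, …]` list and splits it into separate key and value lists before wrapping each side in `make_array`. The standalone sketch below illustrates only that even/odd partition step, using plain `&str` values instead of DataFusion expressions; the helper name `split_interleaved` is invented for this illustration and is not part of the API.

```rust
/// Split an interleaved [k1, v1, k2, v2, ...] argument list into
/// (keys, values) by index parity, mirroring the partition the planner
/// performs before building the two `make_array` calls.
/// `split_interleaved` is a hypothetical helper, shown for illustration only.
fn split_interleaved<T>(args: Vec<T>) -> Result<(Vec<T>, Vec<T>), String> {
    if args.len() % 2 != 0 {
        return Err("make_map requires an even number of arguments".to_string());
    }
    let (keys, values): (Vec<_>, Vec<_>) =
        args.into_iter().enumerate().partition(|(i, _)| i % 2 == 0);
    Ok((
        keys.into_iter().map(|(_, e)| e).collect(),
        values.into_iter().map(|(_, e)| e).collect(),
    ))
}

fn main() {
    // make_map('a', 1, 'b', 2): keys and values arrive interleaved.
    let args = vec!["a", "1", "b", "2"];
    let (keys, values) = split_interleaved(args).unwrap();
    assert_eq!(keys, vec!["a", "b"]);
    assert_eq!(values, vec!["1", "2"]);
}
```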
+use crate::make_array::make_array; +use arrow::array::ArrayData; +use arrow_array::{Array, ArrayRef, MapArray, StructArray}; +use arrow_buffer::{Buffer, ToByteSlice}; +use arrow_schema::{DataType, Field, SchemaBuilder}; +use datafusion_common::{exec_err, ScalarValue}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::collections::VecDeque; use std::sync::Arc; -use arrow::array::{Array, ArrayData, ArrayRef, MapArray, StructArray}; -use arrow::compute::concat; -use arrow::datatypes::{DataType, Field, SchemaBuilder}; -use arrow_buffer::{Buffer, ToByteSlice}; +/// Returns a map created from a key list and a value list +pub fn map(keys: Vec, values: Vec) -> Expr { + let keys = make_array(keys); + let values = make_array(values); + Expr::ScalarFunction(ScalarFunction::new_udf(map_udf(), vec![keys, values])) +} -use datafusion_common::{exec_err, internal_err, ScalarValue}; -use datafusion_common::{not_impl_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +create_func!(MapFunc, map_udf); /// Check if we can evaluate the expr to constant directly. /// @@ -40,42 +48,7 @@ fn can_evaluate_to_const(args: &[ColumnarValue]) -> bool { .all(|arg| matches!(arg, ColumnarValue::Scalar(_))) } -fn make_map(args: &[ColumnarValue]) -> Result { - let can_evaluate_to_const = can_evaluate_to_const(args); - - let (key, value): (Vec<_>, Vec<_>) = args - .chunks_exact(2) - .map(|chunk| { - if let ColumnarValue::Array(_) = chunk[0] { - return not_impl_err!("make_map does not support array keys"); - } - if let ColumnarValue::Array(_) = chunk[1] { - return not_impl_err!("make_map does not support array values"); - } - Ok((chunk[0].clone(), chunk[1].clone())) - }) - .collect::>>()? 
- .into_iter() - .unzip(); - - let keys = ColumnarValue::values_to_arrays(&key)?; - let values = ColumnarValue::values_to_arrays(&value)?; - - let keys: Vec<_> = keys.iter().map(|k| k.as_ref()).collect(); - let values: Vec<_> = values.iter().map(|v| v.as_ref()).collect(); - - let key = match concat(&keys) { - Ok(key) => key, - Err(e) => return internal_err!("Error concatenating keys: {}", e), - }; - let value = match concat(&values) { - Ok(value) => value, - Err(e) => return internal_err!("Error concatenating values: {}", e), - }; - make_map_batch_internal(key, value, can_evaluate_to_const) -} - -fn make_map_batch(args: &[ColumnarValue]) -> Result { +fn make_map_batch(args: &[ColumnarValue]) -> datafusion_common::Result { if args.len() != 2 { return exec_err!( "make_map requires exactly 2 arguments, got {} instead", @@ -90,7 +63,9 @@ fn make_map_batch(args: &[ColumnarValue]) -> Result { make_map_batch_internal(key, value, can_evaluate_to_const) } -fn get_first_array_ref(columnar_value: &ColumnarValue) -> Result { +fn get_first_array_ref( + columnar_value: &ColumnarValue, +) -> datafusion_common::Result { match columnar_value { ColumnarValue::Scalar(value) => match value { ScalarValue::List(array) => Ok(array.value(0)), @@ -106,7 +81,7 @@ fn make_map_batch_internal( keys: ArrayRef, values: ArrayRef, can_evaluate_to_const: bool, -) -> Result { +) -> datafusion_common::Result { if keys.null_count() > 0 { return exec_err!("map key cannot be null"); } @@ -154,115 +129,6 @@ fn make_map_batch_internal( }) } -#[derive(Debug)] -pub struct MakeMap { - signature: Signature, -} - -impl Default for MakeMap { - fn default() -> Self { - Self::new() - } -} - -impl MakeMap { - pub fn new() -> Self { - Self { - signature: Signature::user_defined(Volatility::Immutable), - } - } -} - -impl ScalarUDFImpl for MakeMap { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - "make_map" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.is_empty() { - return exec_err!( - "make_map requires at least one pair of arguments, got 0 instead" - ); - } - if arg_types.len() % 2 != 0 { - return exec_err!( - "make_map requires an even number of arguments, got {} instead", - arg_types.len() - ); - } - - let key_type = &arg_types[0]; - let mut value_type = &arg_types[1]; - - for (i, chunk) in arg_types.chunks_exact(2).enumerate() { - if chunk[0].is_null() { - return exec_err!("make_map key cannot be null at position {}", i); - } - if &chunk[0] != key_type { - return exec_err!( - "make_map requires all keys to have the same type {}, got {} instead at position {}", - key_type, - chunk[0], - i - ); - } - - if !chunk[1].is_null() { - if value_type.is_null() { - value_type = &chunk[1]; - } else if &chunk[1] != value_type { - return exec_err!( - "map requires all values to have the same type {}, got {} instead at position {}", - value_type, - &chunk[1], - i - ); - } - } - } - - let mut result = Vec::new(); - for _ in 0..arg_types.len() / 2 { - result.push(key_type.clone()); - result.push(value_type.clone()); - } - - Ok(result) - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - let key_type = &arg_types[0]; - let mut value_type = &arg_types[1]; - - for chunk in arg_types.chunks_exact(2) { - if !chunk[1].is_null() && value_type.is_null() { - value_type = &chunk[1]; - } - } - - let mut builder = SchemaBuilder::new(); - builder.push(Field::new("key", key_type.clone(), false)); - 
builder.push(Field::new("value", value_type.clone(), true)); - let fields = builder.finish().fields; - Ok(DataType::Map( - Arc::new(Field::new("entries", DataType::Struct(fields), false)), - false, - )) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - make_map(args) - } -} - #[derive(Debug)] pub struct MapFunc { signature: Signature, @@ -295,7 +161,7 @@ impl ScalarUDFImpl for MapFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { if arg_types.len() % 2 != 0 { return exec_err!( "map requires an even number of arguments, got {} instead", @@ -320,12 +186,12 @@ impl ScalarUDFImpl for MapFunc { )) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { make_map_batch(args) } } -fn get_element_type(data_type: &DataType) -> Result<&DataType> { +fn get_element_type(data_type: &DataType) -> datafusion_common::Result<&DataType> { match data_type { DataType::List(element) => Ok(element.data_type()), DataType::LargeList(element) => Ok(element.data_type()), diff --git a/datafusion/functions-array/src/planner.rs b/datafusion/functions-array/src/planner.rs index cfbe99b4b7fd..3f779c9f111e 100644 --- a/datafusion/functions-array/src/planner.rs +++ b/datafusion/functions-array/src/planner.rs @@ -17,14 +17,17 @@ //! SQL planning extensions like [`ArrayFunctionPlanner`] and [`FieldAccessPlanner`] -use datafusion_common::{utils::list_ndims, DFSchema, Result}; +use datafusion_common::{exec_err, utils::list_ndims, DFSchema, Result}; +use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ + expr::AggregateFunctionDefinition, planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr}, - sqlparser, AggregateFunction, Expr, ExprSchemable, GetFieldAccess, + sqlparser, Expr, ExprSchemable, GetFieldAccess, }; use datafusion_functions::expr_fn::get_field; use datafusion_functions_aggregate::nth_value::nth_value_udaf; +use crate::map::map_udf; use crate::{ array_has::array_has_all, expr_fn::{array_append, array_concat, array_prepend}, @@ -97,6 +100,21 @@ impl ExprPlanner for ArrayFunctionPlanner { ) -> Result>> { Ok(PlannerResult::Planned(make_array(exprs))) } + + fn plan_make_map(&self, args: Vec) -> Result>> { + if args.len() % 2 != 0 { + return exec_err!("make_map requires an even number of arguments"); + } + + let (keys, values): (Vec<_>, Vec<_>) = + args.into_iter().enumerate().partition(|(i, _)| i % 2 == 0); + let keys = make_array(keys.into_iter().map(|(_, e)| e).collect()); + let values = make_array(values.into_iter().map(|(_, e)| e).collect()); + + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(map_udf(), vec![keys, values]), + ))) + } } pub struct FieldAccessPlanner; @@ -153,8 +171,9 @@ impl ExprPlanner for FieldAccessPlanner { } fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { - agg_func.func_def - == datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn( - AggregateFunction::ArrayAgg, - ) + if let AggregateFunctionDefinition::UDF(udf) = &agg_func.func_def { + return udf.name() == "array_agg"; + } + + false } diff --git a/datafusion/functions-array/src/remove.rs b/datafusion/functions-array/src/remove.rs index 589dd4d0c41c..0b7cfc283c06 100644 --- a/datafusion/functions-array/src/remove.rs +++ b/datafusion/functions-array/src/remove.rs @@ -228,7 +228,7 @@ fn array_remove_internal( } } -/// For each element of `list_array[i]`, 
removed up to `arr_n[i]` occurences +/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurrences /// of `element_array[i]`. /// /// The type of each **element** in `list_array` must be the same as the type of diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index b143080b1962..0281676cabf2 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -141,8 +141,3 @@ required-features = ["string_expressions"] harness = false name = "upper" required-features = ["string_expressions"] - -[[bench]] -harness = false -name = "map" -required-features = ["core_expressions"] diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 9c410d4e18e8..9227f9e3a2a8 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -444,7 +444,7 @@ fn is_separator(c: char) -> bool { } #[derive(Debug)] -/// Splits a strings like Dictionary(Int32, Int64) into tokens sutable for parsing +/// Splits a strings like Dictionary(Int32, Int64) into tokens suitable for parsing /// /// For example the string "Timestamp(Nanosecond, None)" would be parsed into: /// diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 31bce04beec1..8c5121397284 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -25,7 +25,6 @@ pub mod arrowtypeof; pub mod coalesce; pub mod expr_ext; pub mod getfield; -pub mod map; pub mod named_struct; pub mod nullif; pub mod nvl; @@ -43,8 +42,6 @@ make_udf_function!(r#struct::StructFunc, STRUCT, r#struct); make_udf_function!(named_struct::NamedStructFunc, NAMED_STRUCT, named_struct); make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field); make_udf_function!(coalesce::CoalesceFunc, COALESCE, coalesce); -make_udf_function!(map::MakeMap, MAKE_MAP, make_map); -make_udf_function!(map::MapFunc, MAP, map); pub mod expr_fn { use datafusion_expr::{Expr, Literal}; @@ -81,14 +78,6 @@ pub mod expr_fn { coalesce, "Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL", args, - ),( - make_map, - "Returns a map created from the given keys and values pairs. This function isn't efficient for large maps. Use the `map` function instead.", - args, - ),( - map, - "Returns a map created from a key list and a value list", - args, )); #[doc = "Returns the value of the field with the given name from the struct"] @@ -105,9 +94,6 @@ pub fn functions() -> Vec> { nvl2(), arrow_typeof(), named_struct(), - get_field(), coalesce(), - make_map(), - map(), ] } diff --git a/datafusion/functions/src/core/planner.rs b/datafusion/functions/src/core/planner.rs index 63eaa9874c2b..889f191d592f 100644 --- a/datafusion/functions/src/core/planner.rs +++ b/datafusion/functions/src/core/planner.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. 
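The `CoreFunctionPlanner` change that follows adds a `plan_compound_identifier` hook: the trailing parts of an identifier such as `t1.c0.foo` are resolved by wrapping the column in a `get_field` call, and more than one nested name is still rejected with `not_impl_err!`. As a rough, self-contained illustration of the idea (a toy model, not the DataFusion types), the sketch below folds nested names into nested field accesses.

```rust
/// Toy expression model used only to illustrate folding the trailing parts
/// of a compound identifier into nested field accesses. The real planner
/// builds `get_field(col, lit(name))` and currently supports one nested name.
#[derive(Debug, PartialEq)]
enum Expr {
    Column(String),
    GetField(Box<Expr>, String),
}

fn plan_compound_identifier(column: &str, nested_names: &[String]) -> Expr {
    nested_names.iter().fold(Expr::Column(column.to_string()), |acc, name| {
        Expr::GetField(Box::new(acc), name.clone())
    })
}

fn main() {
    // `t1.c0.foo` resolves column `c0`, then accesses the struct field `foo`.
    let expr = plan_compound_identifier("c0", &["foo".to_string()]);
    assert_eq!(
        expr,
        Expr::GetField(Box::new(Expr::Column("c0".into())), "foo".into())
    );
}
```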
-use datafusion_common::DFSchema; +use arrow::datatypes::Field; use datafusion_common::Result; +use datafusion_common::{not_impl_err, Column, DFSchema, ScalarValue, TableReference}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawDictionaryExpr}; -use datafusion_expr::Expr; +use datafusion_expr::{lit, Expr}; use super::named_struct; @@ -62,4 +63,26 @@ impl ExprPlanner for CoreFunctionPlanner { ScalarFunction::new_udf(crate::string::overlay(), args), ))) } + + fn plan_compound_identifier( + &self, + field: &Field, + qualifier: Option<&TableReference>, + nested_names: &[String], + ) -> Result>> { + // TODO: remove when can support multiple nested identifiers + if nested_names.len() > 1 { + return not_impl_err!( + "Nested identifiers not yet supported for column {}", + Column::from((qualifier, field)).quoted_flat_name() + ); + } + let nested_name = nested_names[0].to_string(); + + let col = Expr::Column(Column::from((qualifier, field))); + let get_field_args = vec![col, lit(ScalarValue::from(nested_name))]; + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::core::get_field(), get_field_args), + ))) + } } diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index c84d1015bd7e..634e28e6f393 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -84,7 +84,7 @@ impl ToLocalTimeFunc { let arg_type = time_value.data_type(); match arg_type { DataType::Timestamp(_, None) => { - // if no timezone specificed, just return the input + // if no timezone specified, just return the input Ok(time_value.clone()) } // If has timezone, adjust the underlying time value. 
The current time value @@ -165,7 +165,7 @@ impl ToLocalTimeFunc { match array.data_type() { Timestamp(_, None) => { - // if no timezone specificed, just return the input + // if no timezone specified, just return the input Ok(time_value.clone()) } Timestamp(Nanosecond, Some(_)) => { diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index ea424c14749e..0e181aa61250 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -82,10 +82,16 @@ impl ScalarUDFImpl for LogFunc { } fn output_ordering(&self, input: &[ExprProperties]) -> Result { - match (input[0].sort_properties, input[1].sort_properties) { - (first @ SortProperties::Ordered(value), SortProperties::Ordered(base)) - if !value.descending && base.descending - || value.descending && !base.descending => + let (base_sort_properties, num_sort_properties) = if input.len() == 1 { + // log(x) defaults to log(10, x) + (SortProperties::Singleton, input[0].sort_properties) + } else { + (input[0].sort_properties, input[1].sort_properties) + }; + match (num_sort_properties, base_sort_properties) { + (first @ SortProperties::Ordered(num), SortProperties::Ordered(base)) + if num.descending != base.descending + && num.nulls_first == base.nulls_first => { Ok(first) } @@ -230,6 +236,7 @@ mod tests { use super::*; + use arrow::compute::SortOptions; use datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_common::DFSchema; use datafusion_expr::execution_props::ExecutionProps; @@ -334,4 +341,112 @@ mod tests { assert_eq!(args[0], lit(2)); assert_eq!(args[1], lit(3)); } + + #[test] + fn test_log_output_ordering() { + // [Unordered, Ascending, Descending, Literal] + let orders = vec![ + ExprProperties::new_unknown(), + ExprProperties::new_unknown().with_order(SortProperties::Ordered( + SortOptions { + descending: false, + nulls_first: true, + }, + )), + ExprProperties::new_unknown().with_order(SortProperties::Ordered( + SortOptions { + descending: true, + nulls_first: true, + }, + )), + ExprProperties::new_unknown().with_order(SortProperties::Singleton), + ]; + + let log = LogFunc::new(); + + // Test log(num) + for order in orders.iter().cloned() { + let result = log.output_ordering(&[order.clone()]).unwrap(); + assert_eq!(result, order.sort_properties); + } + + // Test log(base, num), where `nulls_first` is the same + let mut results = Vec::with_capacity(orders.len() * orders.len()); + for base_order in orders.iter() { + for num_order in orders.iter().cloned() { + let result = log + .output_ordering(&[base_order.clone(), num_order]) + .unwrap(); + results.push(result); + } + } + let expected = vec![ + // base: Unordered + SortProperties::Unordered, + SortProperties::Unordered, + SortProperties::Unordered, + SortProperties::Unordered, + // base: Ascending, num: Unordered + SortProperties::Unordered, + // base: Ascending, num: Ascending + SortProperties::Unordered, + // base: Ascending, num: Descending + SortProperties::Ordered(SortOptions { + descending: true, + nulls_first: true, + }), + // base: Ascending, num: Literal + SortProperties::Ordered(SortOptions { + descending: true, + nulls_first: true, + }), + // base: Descending, num: Unordered + SortProperties::Unordered, + // base: Descending, num: Ascending + SortProperties::Ordered(SortOptions { + descending: false, + nulls_first: true, + }), + // base: Descending, num: Descending + SortProperties::Unordered, + // base: Descending, num: Literal + SortProperties::Ordered(SortOptions { + descending: false, 
+ nulls_first: true, + }), + // base: Literal, num: Unordered + SortProperties::Unordered, + // base: Literal, num: Ascending + SortProperties::Ordered(SortOptions { + descending: false, + nulls_first: true, + }), + // base: Literal, num: Descending + SortProperties::Ordered(SortOptions { + descending: true, + nulls_first: true, + }), + // base: Literal, num: Literal + SortProperties::Singleton, + ]; + assert_eq!(results, expected); + + // Test with different `nulls_first` + let base_order = ExprProperties::new_unknown().with_order( + SortProperties::Ordered(SortOptions { + descending: true, + nulls_first: true, + }), + ); + let num_order = ExprProperties::new_unknown().with_order( + SortProperties::Ordered(SortOptions { + descending: false, + nulls_first: false, + }), + ); + assert_eq!( + log.output_ordering(&[base_order, num_order]).unwrap(), + SortProperties::Unordered + ); + } } diff --git a/datafusion/functions/src/regex/regexpreplace.rs b/datafusion/functions/src/regex/regexpreplace.rs index 378b6ced076c..d820f991be18 100644 --- a/datafusion/functions/src/regex/regexpreplace.rs +++ b/datafusion/functions/src/regex/regexpreplace.rs @@ -562,7 +562,7 @@ mod tests { #[test] fn test_static_pattern_regexp_replace_pattern_error() { let values = StringArray::from(vec!["abc"; 5]); - // Delibaretely using an invalid pattern to see how the single pattern + // Deliberately using an invalid pattern to see how the single pattern // error is propagated on regexp_replace. let patterns = StringArray::from(vec!["["; 5]); let replacements = StringArray::from(vec!["foo"; 5]); diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index a057e4298546..f8ecab9073c4 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -122,15 +122,15 @@ pub fn substr_index(args: &[ArrayRef]) -> Result { let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); let length = if n > 0 { - let splitted = string.split(delimiter); - splitted + let split = string.split(delimiter); + split .take(occurrences) .map(|s| s.len() + delimiter.len()) .sum::() - delimiter.len() } else { - let splitted = string.rsplit(delimiter); - splitted + let split = string.rsplit(delimiter); + split .take(occurrences) .map(|s| s.len() + delimiter.len()) .sum::() diff --git a/datafusion/optimizer/README.md b/datafusion/optimizer/README.md index 5aacfaf59cb1..61bc1cd70145 100644 --- a/datafusion/optimizer/README.md +++ b/datafusion/optimizer/README.md @@ -17,320 +17,6 @@ under the License. --> -# DataFusion Query Optimizer +Please see [Query Optimizer] in the Library User Guide -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory -format. - -DataFusion has modular design, allowing individual crates to be re-used in other projects. - -This crate is a submodule of DataFusion that provides a query optimizer for logical plans, and -contains an extensive set of OptimizerRules that may rewrite the plan and/or its expressions so -they execute more quickly while still computing the same result. - -## Running the Optimizer - -The following code demonstrates the basic flow of creating the optimizer with a default set of optimization rules -and applying it to a logical plan to produce an optimized logical plan. - -```rust - -// We need a logical plan as the starting point. 
There are many ways to build a logical plan: -// -// The `datafusion-expr` crate provides a LogicalPlanBuilder -// The `datafusion-sql` crate provides a SQL query planner that can create a LogicalPlan from SQL -// The `datafusion` crate provides a DataFrame API that can create a LogicalPlan -let logical_plan = ... - -let mut config = OptimizerContext::default(); -let optimizer = Optimizer::new(&config); -let optimized_plan = optimizer.optimize(&logical_plan, &config, observe)?; - -fn observe(plan: &LogicalPlan, rule: &dyn OptimizerRule) { - println!( - "After applying rule '{}':\n{}", - rule.name(), - plan.display_indent() - ) -} -``` - -## Providing Custom Rules - -The optimizer can be created with a custom set of rules. - -```rust -let optimizer = Optimizer::with_rules(vec![ - Arc::new(MyRule {}) -]); -``` - -## Writing Optimization Rules - -Please refer to the -[optimizer_rule.rs](../../datafusion-examples/examples/optimizer_rule.rs) -example to learn more about the general approach to writing optimizer rules and -then move onto studying the existing rules. - -All rules must implement the `OptimizerRule` trait. - -```rust -/// `OptimizerRule` transforms one ['LogicalPlan'] into another which -/// computes the same results, but in a potentially more efficient -/// way. If there are no suitable transformations for the input plan, -/// the optimizer can simply return it as is. -pub trait OptimizerRule { - /// Rewrite `plan` to an optimized form - fn optimize( - &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result; - - /// A human readable name for this optimizer rule - fn name(&self) -> &str; -} -``` - -### General Guidelines - -Rules typical walk the logical plan and walk the expression trees inside operators and selectively mutate -individual operators or expressions. - -Sometimes there is an initial pass that visits the plan and builds state that is used in a second pass that performs -the actual optimization. This approach is used in projection push down and filter push down. - -### Expression Naming - -Every expression in DataFusion has a name, which is used as the column name. For example, in this example the output -contains a single column with the name `"COUNT(aggregate_test_100.c9)"`: - -```text -> select count(c9) from aggregate_test_100; -+------------------------------+ -| COUNT(aggregate_test_100.c9) | -+------------------------------+ -| 100 | -+------------------------------+ -``` - -These names are used to refer to the columns in both subqueries as well as internally from one stage of the LogicalPlan -to another. For example: - -```text -> select "COUNT(aggregate_test_100.c9)" + 1 from (select count(c9) from aggregate_test_100) as sq; -+--------------------------------------------+ -| sq.COUNT(aggregate_test_100.c9) + Int64(1) | -+--------------------------------------------+ -| 101 | -+--------------------------------------------+ -``` - -### Implication - -Because DataFusion identifies columns using a string name, it means it is critical that the names of expressions are -not changed by the optimizer when it rewrites expressions. This is typically accomplished by renaming a rewritten -expression by adding an alias. - -Here is a simple example of such a rewrite. 
The expression `1 + 2` can be internally simplified to 3 but must still be -displayed the same as `1 + 2`: - -```text -> select 1 + 2; -+---------------------+ -| Int64(1) + Int64(2) | -+---------------------+ -| 3 | -+---------------------+ -``` - -Looking at the `EXPLAIN` output we can see that the optimizer has effectively rewritten `1 + 2` into effectively -`3 as "1 + 2"`: - -```text -> explain select 1 + 2; -+---------------+-------------------------------------------------+ -| plan_type | plan | -+---------------+-------------------------------------------------+ -| logical_plan | Projection: Int64(3) AS Int64(1) + Int64(2) | -| | EmptyRelation | -| physical_plan | ProjectionExec: expr=[3 as Int64(1) + Int64(2)] | -| | PlaceholderRowExec | -| | | -+---------------+-------------------------------------------------+ -``` - -If the expression name is not preserved, bugs such as [#3704](https://github.com/apache/datafusion/issues/3704) -and [#3555](https://github.com/apache/datafusion/issues/3555) occur where the expected columns can not be found. - -### Building Expression Names - -There are currently two ways to create a name for an expression in the logical plan. - -```rust -impl Expr { - /// Returns the name of this expression as it should appear in a schema. This name - /// will not include any CAST expressions. - pub fn display_name(&self) -> Result { - create_name(self) - } - - /// Returns a full and complete string representation of this expression. - pub fn canonical_name(&self) -> String { - format!("{}", self) - } -} -``` - -When comparing expressions to determine if they are equivalent, `canonical_name` should be used, and when creating a -name to be used in a schema, `display_name` should be used. - -### Utilities - -There are a number of utility methods provided that take care of some common tasks. - -### ExprVisitor - -The `ExprVisitor` and `ExprVisitable` traits provide a mechanism for applying a visitor pattern to an expression tree. - -Here is an example that demonstrates this. - -```rust -fn extract_subquery_filters(expression: &Expr, extracted: &mut Vec) -> Result<()> { - struct InSubqueryVisitor<'a> { - accum: &'a mut Vec, - } - - impl ExpressionVisitor for InSubqueryVisitor<'_> { - fn pre_visit(self, expr: &Expr) -> Result> { - if let Expr::InSubquery(_) = expr { - self.accum.push(expr.to_owned()); - } - Ok(Recursion::Continue(self)) - } - } - - expression.accept(InSubqueryVisitor { accum: extracted })?; - Ok(()) -} -``` - -### Rewriting Expressions - -The `MyExprRewriter` trait can be implemented to provide a way to rewrite expressions. This rule can then be applied -to an expression by calling `Expr::rewrite` (from the `ExprRewritable` trait). - -The `rewrite` method will perform a depth first walk of the expression and its children to rewrite an expression, -consuming `self` producing a new expression. - -```rust -let mut expr_rewriter = MyExprRewriter {}; -let expr = expr.rewrite(&mut expr_rewriter)?; -``` - -Here is an example implementation which will rewrite `expr BETWEEN a AND b` as `expr >= a AND expr <= b`. Note that the -implementation does not need to perform any recursion since this is handled by the `rewrite` method. 
- -```rust -struct MyExprRewriter {} - -impl ExprRewriter for MyExprRewriter { - fn mutate(&mut self, expr: Expr) -> Result { - match expr { - Expr::Between { - negated, - expr, - low, - high, - } => { - let expr: Expr = expr.as_ref().clone(); - let low: Expr = low.as_ref().clone(); - let high: Expr = high.as_ref().clone(); - if negated { - Ok(expr.clone().lt(low).or(expr.clone().gt(high))) - } else { - Ok(expr.clone().gt_eq(low).and(expr.clone().lt_eq(high))) - } - } - _ => Ok(expr.clone()), - } - } -} -``` - -### optimize_children - -Typically a rule is applied recursively to all operators within a query plan. Rather than duplicate -that logic in each rule, an `optimize_children` method is provided. This recursively invokes the `optimize` method on -the plan's children and then returns a node of the same type. - -```rust -fn optimize( - &self, - plan: &LogicalPlan, - _config: &mut OptimizerConfig, -) -> Result { - // recurse down and optimize children first - let plan = utils::optimize_children(self, plan, _config)?; - - ... -} -``` - -### Writing Tests - -There should be unit tests in the same file as the new rule that test the effect of the rule being applied to a plan -in isolation (without any other rule being applied). - -There should also be a test in `integration-tests.rs` that tests the rule as part of the overall optimization process. - -### Debugging - -The `EXPLAIN VERBOSE` command can be used to show the effect of each optimization rule on a query. - -In the following example, the `type_coercion` and `simplify_expressions` passes have simplified the plan so that it returns the constant `"3.2"` rather than doing a computation at execution time. - -```text -> explain verbose select cast(1 + 2.2 as string) as foo; -+------------------------------------------------------------+---------------------------------------------------------------------------+ -| plan_type | plan | -+------------------------------------------------------------+---------------------------------------------------------------------------+ -| initial_logical_plan | Projection: CAST(Int64(1) + Float64(2.2) AS Utf8) AS foo | -| | EmptyRelation | -| logical_plan after type_coercion | Projection: CAST(CAST(Int64(1) AS Float64) + Float64(2.2) AS Utf8) AS foo | -| | EmptyRelation | -| logical_plan after simplify_expressions | Projection: Utf8("3.2") AS foo | -| | EmptyRelation | -| logical_plan after unwrap_cast_in_comparison | SAME TEXT AS ABOVE | -| logical_plan after decorrelate_where_exists | SAME TEXT AS ABOVE | -| logical_plan after decorrelate_where_in | SAME TEXT AS ABOVE | -| logical_plan after scalar_subquery_to_join | SAME TEXT AS ABOVE | -| logical_plan after subquery_filter_to_join | SAME TEXT AS ABOVE | -| logical_plan after simplify_expressions | SAME TEXT AS ABOVE | -| logical_plan after eliminate_filter | SAME TEXT AS ABOVE | -| logical_plan after reduce_cross_join | SAME TEXT AS ABOVE | -| logical_plan after common_sub_expression_eliminate | SAME TEXT AS ABOVE | -| logical_plan after eliminate_limit | SAME TEXT AS ABOVE | -| logical_plan after projection_push_down | SAME TEXT AS ABOVE | -| logical_plan after rewrite_disjunctive_predicate | SAME TEXT AS ABOVE | -| logical_plan after reduce_outer_join | SAME TEXT AS ABOVE | -| logical_plan after filter_push_down | SAME TEXT AS ABOVE | -| logical_plan after limit_push_down | SAME TEXT AS ABOVE | -| logical_plan after single_distinct_aggregation_to_group_by | SAME TEXT AS ABOVE | -| logical_plan | Projection: Utf8("3.2") AS foo | -| | EmptyRelation | 
-| initial_physical_plan | ProjectionExec: expr=[3.2 as foo] | -| | PlaceholderRowExec | -| | | -| physical_plan after aggregate_statistics | SAME TEXT AS ABOVE | -| physical_plan after join_selection | SAME TEXT AS ABOVE | -| physical_plan after coalesce_batches | SAME TEXT AS ABOVE | -| physical_plan after repartition | SAME TEXT AS ABOVE | -| physical_plan after add_merge_exec | SAME TEXT AS ABOVE | -| physical_plan | ProjectionExec: expr=[3.2 as foo] | -| | PlaceholderRowExec | -| | | -+------------------------------------------------------------+---------------------------------------------------------------------------+ -``` - -[df]: https://crates.io/crates/datafusion +[query optimizer]: https://datafusion.apache.org/library-user-guide/query-optimizer.html diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 959ffdaaa212..fa8aeb86ed31 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -35,7 +35,7 @@ pub struct CountWildcardRule {} impl CountWildcardRule { pub fn new() -> Self { - CountWildcardRule {} + Self {} } } @@ -59,14 +59,14 @@ fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { func_def: AggregateFunctionDefinition::UDF(udf), args, .. - } if udf.name() == "count" && args.len() == 1 && is_wildcard(&args[0])) + } if udf.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty())) } fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool { let args = &window_function.args; matches!(window_function.fun, WindowFunctionDefinition::AggregateUDF(ref udaf) - if udaf.name() == "count" && args.len() == 1 && is_wildcard(&args[0])) + if udaf.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty())) } fn analyze_internal(plan: LogicalPlan) -> Result> { diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index db39f8f7737d..9856ea271ca5 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -159,11 +159,11 @@ fn check_inner_plan( let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate) .into_iter() .partition(|e| e.contains_outer()); - let maybe_unsupport = correlated + let maybe_unsupported = correlated .into_iter() .filter(|expr| !can_pullup_over_aggregation(expr)) .collect::>(); - if is_aggregate && is_scalar && !maybe_unsupport.is_empty() { + if is_aggregate && is_scalar && !maybe_unsupported.is_empty() { return plan_err!( "Correlated column is not allowed in predicate: {predicate}" ); diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 80a8c864e431..50fb1b8193ce 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -103,6 +103,14 @@ fn analyze_internal( // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3) schema.merge(external_schema); + // Coerce filter predicates to boolean (handles `WHERE NULL`) + let plan = if let LogicalPlan::Filter(mut filter) = plan { + filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?; + LogicalPlan::Filter(filter) + } else { + plan + }; + let mut expr_rewrite = TypeCoercionRewriter::new(&schema); let name_preserver = NamePreserver::new(&plan); diff --git 
a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index e4b36652974d..bbf2091c2217 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -248,7 +248,7 @@ impl CommonSubexprEliminate { } /// Rewrites the expression in `exprs_list` with common sub-expressions - /// replaced with a new colum and adds a ProjectionExec on top of `input` + /// replaced with a new column and adds a ProjectionExec on top of `input` /// which computes any replaced common sub-expressions. /// /// Returns a tuple of: @@ -636,7 +636,7 @@ impl CommonSubexprEliminate { /// Returns the window expressions, and the input to the deepest child /// LogicalPlan. /// -/// For example, if the input widnow looks like +/// For example, if the input window looks like /// /// ```text /// LogicalPlan::Window(exprs=[a, b, c]) diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 4e3ca7e33a2e..b6d49490d437 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -1232,7 +1232,7 @@ mod tests { } #[test] - fn in_subquery_muti_project_subquery_cols() -> Result<()> { + fn in_subquery_multi_project_subquery_cols() -> Result<()> { let table_scan = test_table_scan()?; let subquery_scan = test_table_scan_with_name("sq")?; diff --git a/datafusion/optimizer/src/optimize_projections/required_indices.rs b/datafusion/optimizer/src/optimize_projections/required_indices.rs index 3f32a0c36a9a..a9a18898c82e 100644 --- a/datafusion/optimizer/src/optimize_projections/required_indices.rs +++ b/datafusion/optimizer/src/optimize_projections/required_indices.rs @@ -160,7 +160,7 @@ impl RequiredIndicies { (l, r.map_indices(|idx| idx - n)) } - /// Partitions the indicies in this instance into two groups based on the + /// Partitions the indices in this instance into two groups based on the /// given predicate function `f`. fn partition(&self, f: F) -> (Self, Self) where diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 20e2ac07dffd..ad9be449d9ab 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -261,8 +261,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::InSubquery(_) | Expr::ScalarSubquery(_) | Expr::OuterReferenceColumn(_, _) - | Expr::Unnest(_) - | Expr::ScalarFunction(_) => { + | Expr::Unnest(_) => { is_evaluate = false; Ok(TreeNodeRecursion::Stop) } @@ -284,7 +283,8 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::Case(_) | Expr::Cast(_) | Expr::TryCast(_) - | Expr::InList { .. } => Ok(TreeNodeRecursion::Continue), + | Expr::InList { .. } + | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue), Expr::Sort(_) | Expr::AggregateFunction(_) | Expr::WindowFunction(_) @@ -1020,7 +1020,7 @@ impl OptimizerRule for PushDownFilter { /// ``` fn rewrite_projection( predicates: Vec, - projection: Projection, + mut projection: Projection, ) -> Result<(Transformed, Option)> { // A projection is filter-commutable if it do not contain volatile predicates or contain volatile // predicates that are not used in the filter. However, we should re-writes all predicate expressions. @@ -1053,11 +1053,13 @@ fn rewrite_projection( // E.g. 
in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" let new_filter = LogicalPlan::Filter(Filter::try_new( replace_cols_by_name(expr, &non_volatile_map)?, - Arc::clone(&projection.input), + std::mem::take(&mut projection.input), )?); + projection.input = Arc::new(new_filter); + Ok(( - insert_below(LogicalPlan::Projection(projection), new_filter)?, + Transformed::yes(LogicalPlan::Projection(projection)), conjunction(keep_predicates), )) } @@ -1913,7 +1915,7 @@ mod tests { assert_optimized_plan_eq(plan, expected) } - /// post-join predicates with columns from both sides are converted to join filterss + /// post-join predicates with columns from both sides are converted to join filters #[test] fn filter_join_on_common_dependent() -> Result<()> { let table_scan = test_table_scan()?; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 8414f39f3060..56556f387d1b 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -478,7 +478,7 @@ struct ConstEvaluator<'a> { #[allow(dead_code)] /// The simplify result of ConstEvaluator enum ConstSimplifyResult { - // Expr was simplifed and contains the new expression + // Expr was simplified and contains the new expression Simplified(ScalarValue), // Expr was not simplified and original value is returned NotSimplified(ScalarValue), @@ -519,7 +519,7 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { fn f_up(&mut self, expr: Expr) -> Result> { match self.can_evaluate.pop() { // Certain expressions such as `CASE` and `COALESCE` are short circuiting - // and may not evalute all their sub expressions. Thus if + // and may not evaluate all their sub expressions. 
Thus if // if any error is countered during simplification, return the original // so that normal evaluation can occur Some(true) => { diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 7238dd5bbd97..e0f50a470d43 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -893,7 +893,7 @@ mod tests { DataType::Timestamp(TimeUnit::Nanosecond, utc) } - // a dictonary type for storing string tags + // a dictionary type for storing string tags fn dictionary_tag_type() -> DataType { DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) } diff --git a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/accumulate.rs index f109079f6a26..3fcd570f514e 100644 --- a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/accumulate.rs @@ -410,7 +410,7 @@ pub fn accumulate_indices( }, ); - // handle any remaining bits (after the intial 64) + // handle any remaining bits (after the initial 64) let remainder_bits = bit_chunks.remainder_bits(); group_indices_remainder .iter() @@ -835,7 +835,7 @@ mod test { } } - /// Parallel implementaiton of NullState to check expected values + /// Parallel implementation of NullState to check expected values #[derive(Debug, Default)] struct MockNullState { /// group indices that had values that passed the filter diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 0e245fd0a66a..8c5f9f9e5a7e 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -23,7 +23,7 @@ pub mod tdigest; pub mod utils; use arrow::datatypes::{DataType, Field, Schema}; -use datafusion_common::{not_impl_err, Result}; +use datafusion_common::{not_impl_err, DFSchema, Result}; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::type_coercion::aggregates::check_arg_count; use datafusion_expr::ReversedUDAF; @@ -51,6 +51,10 @@ use datafusion_expr::utils::AggregateOrderSensitivity; /// /// `input_exprs` and `sort_exprs` are used for customizing Accumulator as the arguments in `AccumulatorArgs`, /// if you don't need them it is fine to pass empty slice `&[]`. 
+/// +/// `is_reversed` is used to indicate whether the aggregation is running in reverse order, +/// it could be used to hint Accumulator to accumulate in the reversed order, +/// you can just set to false if you are not reversing expression #[allow(clippy::too_many_arguments)] pub fn create_aggregate_expr( fun: &AggregateUDF, @@ -62,6 +66,7 @@ pub fn create_aggregate_expr( name: impl Into, ignore_nulls: bool, is_distinct: bool, + is_reversed: bool, ) -> Result> { debug_assert_eq!(sort_exprs.len(), ordering_req.len()); @@ -81,6 +86,61 @@ pub fn create_aggregate_expr( .map(|e| e.expr.data_type(schema)) .collect::>>()?; + let ordering_fields = ordering_fields(ordering_req, &ordering_types); + let name = name.into(); + + Ok(Arc::new(AggregateFunctionExpr { + fun: fun.clone(), + args: input_phy_exprs.to_vec(), + logical_args: input_exprs.to_vec(), + data_type: fun.return_type(&input_exprs_types)?, + name, + schema: schema.clone(), + dfschema: DFSchema::empty(), + sort_exprs: sort_exprs.to_vec(), + ordering_req: ordering_req.to_vec(), + ignore_nulls, + ordering_fields, + is_distinct, + input_type: input_exprs_types[0].clone(), + is_reversed, + })) +} + +#[allow(clippy::too_many_arguments)] +// This is not for external usage, consider creating with `create_aggregate_expr` instead. +pub fn create_aggregate_expr_with_dfschema( + fun: &AggregateUDF, + input_phy_exprs: &[Arc], + input_exprs: &[Expr], + sort_exprs: &[Expr], + ordering_req: &[PhysicalSortExpr], + dfschema: &DFSchema, + name: impl Into, + ignore_nulls: bool, + is_distinct: bool, + is_reversed: bool, +) -> Result> { + debug_assert_eq!(sort_exprs.len(), ordering_req.len()); + + let schema: Schema = dfschema.into(); + + let input_exprs_types = input_phy_exprs + .iter() + .map(|arg| arg.data_type(&schema)) + .collect::>>()?; + + check_arg_count( + fun.name(), + &input_exprs_types, + &fun.signature().type_signature, + )?; + + let ordering_types = ordering_req + .iter() + .map(|e| e.expr.data_type(&schema)) + .collect::>>()?; + let ordering_fields = ordering_fields(ordering_req, &ordering_types); Ok(Arc::new(AggregateFunctionExpr { @@ -90,12 +150,14 @@ pub fn create_aggregate_expr( data_type: fun.return_type(&input_exprs_types)?, name: name.into(), schema: schema.clone(), + dfschema: dfschema.clone(), sort_exprs: sort_exprs.to_vec(), ordering_req: ordering_req.to_vec(), ignore_nulls, ordering_fields, is_distinct, input_type: input_exprs_types[0].clone(), + is_reversed, })) } @@ -261,6 +323,7 @@ pub struct AggregateFunctionExpr { data_type: DataType, name: String, schema: Schema, + dfschema: DFSchema, // The logical order by expressions sort_exprs: Vec, // The physical order by expressions @@ -270,6 +333,7 @@ pub struct AggregateFunctionExpr { // fields used for order sensitive aggregation functions ordering_fields: Vec, is_distinct: bool, + is_reversed: bool, input_type: DataType, } @@ -288,6 +352,11 @@ impl AggregateFunctionExpr { pub fn ignore_nulls(&self) -> bool { self.ignore_nulls } + + /// Return if the aggregation is distinct + pub fn is_reversed(&self) -> bool { + self.is_reversed + } } impl AggregateExpr for AggregateFunctionExpr { @@ -320,12 +389,14 @@ impl AggregateExpr for AggregateFunctionExpr { let acc_args = AccumulatorArgs { data_type: &self.data_type, schema: &self.schema, + dfschema: &self.dfschema, ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, input_exprs: &self.logical_args, name: &self.name, + is_reversed: self.is_reversed, }; 
self.fun.accumulator(acc_args) @@ -335,18 +406,20 @@ impl AggregateExpr for AggregateFunctionExpr { let args = AccumulatorArgs { data_type: &self.data_type, schema: &self.schema, + dfschema: &self.dfschema, ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, input_exprs: &self.logical_args, name: &self.name, + is_reversed: self.is_reversed, }; let accumulator = self.fun.create_sliding_accumulator(args)?; // Accumulators that have window frame startings different - // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to + // than `UNBOUNDED PRECEDING`, such as `1 PRECEDING`, need to // implement retract_batch method in order to run correctly // currently in DataFusion. // @@ -377,7 +450,7 @@ impl AggregateExpr for AggregateFunctionExpr { // 3. Third sum we add to the state sum value between `[2, 3)` // (`[0, 2)` is already in the state sum). Also we need to // retract values between `[0, 1)` by this way we can obtain sum - // between [1, 3) which is indeed the apropriate range. + // between [1, 3) which is indeed the appropriate range. // // When we use `UNBOUNDED PRECEDING` in the query starting // index will always be 0 for the desired range, and hence the @@ -405,12 +478,14 @@ impl AggregateExpr for AggregateFunctionExpr { let args = AccumulatorArgs { data_type: &self.data_type, schema: &self.schema, + dfschema: &self.dfschema, ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, input_exprs: &self.logical_args, name: &self.name, + is_reversed: self.is_reversed, }; self.fun.groups_accumulator_supported(args) } @@ -419,12 +494,14 @@ impl AggregateExpr for AggregateFunctionExpr { let args = AccumulatorArgs { data_type: &self.data_type, schema: &self.schema, + dfschema: &self.dfschema, ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, input_exprs: &self.logical_args, name: &self.name, + is_reversed: self.is_reversed, }; self.fun.create_groups_accumulator(args) } @@ -462,16 +539,17 @@ impl AggregateExpr for AggregateFunctionExpr { else { return Ok(None); }; - create_aggregate_expr( + create_aggregate_expr_with_dfschema( &updated_fn, &self.args, &self.logical_args, &self.sort_exprs, &self.ordering_req, - &self.schema, + &self.dfschema, self.name(), self.ignore_nulls, self.is_distinct, + self.is_reversed, ) .map(Some) } @@ -495,18 +573,24 @@ impl AggregateExpr for AggregateFunctionExpr { }) .collect::>(); let mut name = self.name().to_string(); - replace_order_by_clause(&mut name); + // If the function is changed, we need to reverse order_by clause as well + // i.e. 
First(a order by b asc null first) -> Last(a order by b desc null last) + if self.fun().name() == reverse_udf.name() { + } else { + replace_order_by_clause(&mut name); + } replace_fn_name_clause(&mut name, self.fun.name(), reverse_udf.name()); - let reverse_aggr = create_aggregate_expr( + let reverse_aggr = create_aggregate_expr_with_dfschema( &reverse_udf, &self.args, &self.logical_args, &reverse_sort_exprs, &reverse_ordering_req, - &self.schema, + &self.dfschema, name, self.ignore_nulls, self.is_distinct, + !self.is_reversed, ) .unwrap(); diff --git a/datafusion/physical-expr-common/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs index bff571f5b5be..23280701013d 100644 --- a/datafusion/physical-expr-common/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -355,7 +355,7 @@ where assert_eq!(values.len(), batch_hashes.len()); for (value, &hash) in values.iter().zip(batch_hashes.iter()) { - // hande null value + // handle null value let Some(value) = value else { let payload = if let Some(&(payload, _offset)) = self.null.as_ref() { payload @@ -439,7 +439,7 @@ where // Put the small values into buffer and offsets so it // appears the output array, and store that offset // so the bytes can be compared if needed - let offset = self.buffer.len(); // offset of start fof data + let offset = self.buffer.len(); // offset of start for data self.buffer.append_slice(value); self.offsets.push(O::usize_as(self.buffer.len())); diff --git a/datafusion/physical-expr-common/src/expressions/column.rs b/datafusion/physical-expr-common/src/expressions/column.rs index 956c33d59b20..5397599ea2dc 100644 --- a/datafusion/physical-expr-common/src/expressions/column.rs +++ b/datafusion/physical-expr-common/src/expressions/column.rs @@ -80,7 +80,7 @@ impl PhysicalExpr for Column { Ok(input_schema.field(self.index).data_type().clone()) } - /// Decide whehter this expression is nullable, given the schema of the input + /// Decide whether this expression is nullable, given the schema of the input fn nullable(&self, input_schema: &Schema) -> Result { self.bounds_check(input_schema)?; Ok(input_schema.field(self.index).is_nullable()) @@ -135,3 +135,49 @@ impl Column { pub fn col(name: &str, schema: &Schema) -> Result> { Ok(Arc::new(Column::new_with_schema(name, schema)?)) } + +#[cfg(test)] +mod test { + use super::Column; + use crate::physical_expr::PhysicalExpr; + + use arrow::array::StringArray; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion_common::Result; + + use std::sync::Arc; + + #[test] + fn out_of_bounds_data_type() { + let schema = Schema::new(vec![Field::new("foo", DataType::Utf8, true)]); + let col = Column::new("id", 9); + let error = col.data_type(&schema).expect_err("error").strip_backtrace(); + assert!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ + but input schema only has 1 columns: [\"foo\"].\nThis was likely caused by a bug in \ + DataFusion's code and we would welcome that you file an bug report in our issue tracker".starts_with(&error)) + } + + #[test] + fn out_of_bounds_nullable() { + let schema = Schema::new(vec![Field::new("foo", DataType::Utf8, true)]); + let col = Column::new("id", 9); + let error = col.nullable(&schema).expect_err("error").strip_backtrace(); + assert!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ + but input schema only has 1 columns: [\"foo\"].\nThis was likely caused by a bug in \ + 
DataFusion's code and we would welcome that you file an bug report in our issue tracker".starts_with(&error)) + } + + #[test] + fn out_of_bounds_evaluate() -> Result<()> { + let schema = Schema::new(vec![Field::new("foo", DataType::Utf8, true)]); + let data: StringArray = vec!["data"].into(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?; + let col = Column::new("id", 9); + let error = col.evaluate(&batch).expect_err("error").strip_backtrace(); + assert!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ + but input schema only has 1 columns: [\"foo\"].\nThis was likely caused by a bug in \ + DataFusion's code and we would welcome that you file an bug report in our issue tracker".starts_with(&error)); + Ok(()) + } +} diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 1998f1439646..c74fb9c2d1b7 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -20,13 +20,15 @@ use std::fmt::{Debug, Display}; use std::hash::{Hash, Hasher}; use std::sync::Arc; +use crate::expressions::column::Column; use crate::utils::scatter; use arrow::array::BooleanArray; use arrow::compute::filter_record_batch; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use datafusion_common::{internal_err, not_impl_err, Result}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{internal_err, not_impl_err, plan_err, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::ColumnarValue; @@ -191,6 +193,33 @@ pub fn with_new_children_if_necessary( } } +/// Rewrites an expression according to new schema; i.e. changes the columns it +/// refers to with the column at corresponding index in the new schema. Returns +/// an error if the given schema has fewer columns than the original schema. +/// Note that the resulting expression may not be valid if data types in the +/// new schema is incompatible with expression nodes. +pub fn with_new_schema( + expr: Arc, + schema: &SchemaRef, +) -> Result> { + Ok(expr + .transform_up(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { + let idx = col.index(); + let Some(field) = schema.fields().get(idx) else { + return plan_err!( + "New schema has fewer columns than original schema" + ); + }; + let new_col = Column::new(field.name(), idx); + Ok(Transformed::yes(Arc::new(new_col) as _)) + } else { + Ok(Transformed::no(expr)) + } + })? 
+ .data) +} + pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { if any.is::>() { any.downcast_ref::>() diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 8fb1356a8092..2b506b74216f 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -22,12 +22,12 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; -use crate::utils::limited_convert_logical_expr_to_physical_expr; +use crate::utils::limited_convert_logical_expr_to_physical_expr_with_dfschema; use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, DFSchema, Result}; use datafusion_expr::{ColumnarValue, Expr}; /// Represents Sort operation for a column in a RecordBatch @@ -275,9 +275,9 @@ pub type LexRequirementRef<'a> = &'a [PhysicalSortRequirement]; /// Converts each [`Expr::Sort`] into a corresponding [`PhysicalSortExpr`]. /// Returns an error if the given logical expression is not a [`Expr::Sort`]. -pub fn limited_convert_logical_sort_exprs_to_physical( +pub fn limited_convert_logical_sort_exprs_to_physical_with_dfschema( exprs: &[Expr], - schema: &Schema, + dfschema: &DFSchema, ) -> Result> { // Construct PhysicalSortExpr objects from Expr objects: let mut sort_exprs = vec![]; @@ -286,7 +286,10 @@ pub fn limited_convert_logical_sort_exprs_to_physical( return exec_err!("Expects to receive sort expression"); }; sort_exprs.push(PhysicalSortExpr::new( - limited_convert_logical_expr_to_physical_expr(sort.expr.as_ref(), schema)?, + limited_convert_logical_expr_to_physical_expr_with_dfschema( + sort.expr.as_ref(), + dfschema, + )?, SortOptions { descending: !sort.asc, nulls_first: sort.nulls_first, diff --git a/datafusion/physical-expr-common/src/utils.rs b/datafusion/physical-expr-common/src/utils.rs index 44622bd309df..0978a906a5dc 100644 --- a/datafusion/physical-expr-common/src/utils.rs +++ b/datafusion/physical-expr-common/src/utils.rs @@ -19,15 +19,15 @@ use std::sync::Arc; use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; -use arrow::datatypes::Schema; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, DFSchema, Result}; use datafusion_expr::expr::Alias; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::Expr; +use crate::expressions::column::Column; use crate::expressions::literal::Literal; -use crate::expressions::{self, CastExpr}; +use crate::expressions::CastExpr; use crate::physical_expr::PhysicalExpr; use crate::sort_expr::PhysicalSortExpr; use crate::tree_node::ExprContext; @@ -110,19 +110,22 @@ pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec`. /// If conversion is not supported yet, returns Error. -pub fn limited_convert_logical_expr_to_physical_expr( +pub fn limited_convert_logical_expr_to_physical_expr_with_dfschema( expr: &Expr, - schema: &Schema, + dfschema: &DFSchema, ) -> Result> { match expr { - Expr::Alias(Alias { expr, .. }) => { - Ok(limited_convert_logical_expr_to_physical_expr(expr, schema)?) + Expr::Alias(Alias { expr, .. 
}) => Ok( + limited_convert_logical_expr_to_physical_expr_with_dfschema(expr, dfschema)?, + ), + Expr::Column(col) => { + let idx = dfschema.index_of_column(col)?; + Ok(Arc::new(Column::new(&col.name, idx))) } - Expr::Column(col) => expressions::column::col(&col.name, schema), Expr::Cast(cast_expr) => Ok(Arc::new(CastExpr::new( - limited_convert_logical_expr_to_physical_expr( + limited_convert_logical_expr_to_physical_expr_with_dfschema( cast_expr.expr.as_ref(), - schema, + dfschema, )?, cast_expr.data_type.clone(), None, diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index 9cc7bdc465fb..862edd9c1fac 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -40,6 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) { // create input data let mut c1 = Int32Builder::new(); let mut c2 = StringBuilder::new(); + let mut c3 = StringBuilder::new(); for i in 0..1000 { c1.append_value(i); if i % 7 == 0 { @@ -47,14 +48,21 @@ fn criterion_benchmark(c: &mut Criterion) { } else { c2.append_value(&format!("string {i}")); } + if i % 9 == 0 { + c3.append_null(); + } else { + c3.append_value(&format!("other string {i}")); + } } let c1 = Arc::new(c1.finish()); let c2 = Arc::new(c2.finish()); + let c3 = Arc::new(c3.finish()); let schema = Schema::new(vec![ Field::new("c1", DataType::Int32, true), Field::new("c2", DataType::Utf8, true), + Field::new("c3", DataType::Utf8, true), ]); - let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2, c3]).unwrap(); // use same predicate for all benchmarks let predicate = Arc::new(BinaryExpr::new( @@ -63,7 +71,7 @@ fn criterion_benchmark(c: &mut Criterion) { make_lit_i32(500), )); - // CASE WHEN expr THEN 1 ELSE 0 END + // CASE WHEN c1 <= 500 THEN 1 ELSE 0 END c.bench_function("case_when: scalar or scalar", |b| { let expr = Arc::new( CaseExpr::try_new( @@ -76,13 +84,38 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) }); - // CASE WHEN expr THEN col ELSE null END + // CASE WHEN c1 <= 500 THEN c2 [ELSE NULL] END c.bench_function("case_when: column or null", |b| { + let expr = Arc::new( + CaseExpr::try_new(None, vec![(predicate.clone(), make_col("c2", 1))], None) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + // CASE WHEN c1 <= 500 THEN c2 ELSE c3 END + c.bench_function("case_when: expr or expr", |b| { let expr = Arc::new( CaseExpr::try_new( None, vec![(predicate.clone(), make_col("c2", 1))], - Some(Arc::new(Literal::new(ScalarValue::Utf8(None)))), + Some(make_col("c3", 2)), + ) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + // CASE c1 WHEN 1 THEN c2 WHEN 2 THEN c3 END + c.bench_function("case_when: CASE expr", |b| { + let expr = Arc::new( + CaseExpr::try_new( + Some(make_col("c1", 0)), + vec![ + (make_lit_i32(1), make_col("c2", 1)), + (make_lit_i32(2), make_col("c3", 2)), + ], + None, ) .unwrap(), ); diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs deleted file mode 100644 index 0d5ed730e283..000000000000 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ /dev/null @@ -1,185 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
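As a reviewer aid, a minimal standalone sketch (not part of the patch) of the lookup the new `_with_dfschema` conversion relies on: a logical column is resolved to a flat index through `DFSchema::index_of_column`, and that index is all the physical `Column` keeps. The schema and column names below are invented for illustration.

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{Column, DFSchema, Result};

fn main() -> Result<()> {
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Utf8, true),
    ]);
    // Build an unqualified DFSchema from the Arrow schema and resolve a logical
    // column reference to its index, as the rewritten conversion does.
    let dfschema = DFSchema::try_from(schema)?;
    let idx = dfschema.index_of_column(&Column::from_name("b"))?;
    assert_eq!(idx, 1);
    Ok(())
}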
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use arrow_array::Array; -use datafusion_common::cast::as_list_array; -use datafusion_common::utils::array_into_list_array_nullable; -use datafusion_common::Result; -use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; -use std::any::Any; -use std::sync::Arc; - -/// ARRAY_AGG aggregate expression -#[derive(Debug)] -pub struct ArrayAgg { - /// Column name - name: String, - /// The DataType for the input expression - input_data_type: DataType, - /// The input expression - expr: Arc, -} - -impl ArrayAgg { - /// Create a new ArrayAgg aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - input_data_type: data_type, - expr, - } - } -} - -impl AggregateExpr for ArrayAgg { - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new_list( - &self.name, - // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), true), - true, - )) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(ArrayAggAccumulator::try_new( - &self.input_data_type, - )?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new_list( - format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), true), - true, - )]) - } - - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for ArrayAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -#[derive(Debug)] -pub(crate) struct ArrayAggAccumulator { - values: Vec, - datatype: DataType, -} - -impl ArrayAggAccumulator { - /// new array_agg accumulator based on given item data type - pub fn try_new(datatype: &DataType) -> Result { - Ok(Self { - values: vec![], - datatype: datatype.clone(), - }) - } -} - -impl Accumulator for ArrayAggAccumulator { - // Append value like Int64Array(1,2,3) - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - assert!(values.len() == 1, "array_agg can only take 1 param!"); - - let val = Arc::clone(&values[0]); - if val.len() > 0 { - self.values.push(val); - } - Ok(()) - } - - // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6)) - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - 
if states.is_empty() { - return Ok(()); - } - assert!(states.len() == 1, "array_agg states must be singleton!"); - - let list_arr = as_list_array(&states[0])?; - for arr in list_arr.iter().flatten() { - self.values.push(arr); - } - Ok(()) - } - - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn evaluate(&mut self) -> Result { - // Transform Vec to ListArr - let element_arrays: Vec<&dyn Array> = - self.values.iter().map(|a| a.as_ref()).collect(); - - if element_arrays.is_empty() { - return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); - } - - let concated_array = arrow::compute::concat(&element_arrays)?; - let list_array = array_into_list_array_nullable(concated_array); - - Ok(ScalarValue::List(Arc::new(list_array))) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) - + self - .values - .iter() - .map(|arr| arr.get_array_memory_size()) - .sum::() - + self.datatype.size() - - std::mem::size_of_val(&self.datatype) - } -} diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs deleted file mode 100644 index eca6e4ce4f65..000000000000 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ /dev/null @@ -1,433 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Implementations for DISTINCT expressions, e.g. `COUNT(DISTINCT c)` - -use std::any::Any; -use std::collections::HashSet; -use std::fmt::Debug; -use std::sync::Arc; - -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use arrow_array::cast::AsArray; - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; - -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::Accumulator; - -/// Expression for a ARRAY_AGG(DISTINCT) aggregation. 
-#[derive(Debug)] -pub struct DistinctArrayAgg { - /// Column name - name: String, - /// The DataType for the input expression - input_data_type: DataType, - /// The input expression - expr: Arc, -} - -impl DistinctArrayAgg { - /// Create a new DistinctArrayAgg aggregate function - pub fn new( - expr: Arc, - name: impl Into, - input_data_type: DataType, - ) -> Self { - let name = name.into(); - Self { - name, - input_data_type, - expr, - } - } -} - -impl AggregateExpr for DistinctArrayAgg { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new_list( - &self.name, - // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), true), - true, - )) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(DistinctArrayAggAccumulator::try_new( - &self.input_data_type, - )?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new_list( - format_state_name(&self.name, "distinct_array_agg"), - Field::new("item", self.input_data_type.clone(), true), - true, - )]) - } - - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for DistinctArrayAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -#[derive(Debug)] -struct DistinctArrayAggAccumulator { - values: HashSet, - datatype: DataType, -} - -impl DistinctArrayAggAccumulator { - pub fn try_new(datatype: &DataType) -> Result { - Ok(Self { - values: HashSet::new(), - datatype: datatype.clone(), - }) - } -} - -impl Accumulator for DistinctArrayAggAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - assert_eq!(values.len(), 1, "batch input should only include 1 column!"); - - let array = &values[0]; - - for i in 0..array.len() { - let scalar = ScalarValue::try_from_array(&array, i)?; - self.values.insert(scalar); - } - - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - - states[0] - .as_list::() - .iter() - .flatten() - .try_for_each(|val| self.update_batch(&[val])) - } - - fn evaluate(&mut self) -> Result { - let values: Vec = self.values.iter().cloned().collect(); - if values.is_empty() { - return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); - } - let arr = ScalarValue::new_list(&values, &self.datatype, true); - Ok(ScalarValue::List(arr)) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) + ScalarValue::size_of_hashset(&self.values) - - std::mem::size_of_val(&self.values) - + self.datatype.size() - - std::mem::size_of_val(&self.datatype) - } -} - -#[cfg(test)] -mod tests { - - use super::*; - use crate::expressions::col; - use crate::expressions::tests::aggregate; - use arrow::array::Int32Array; - use arrow::datatypes::Schema; - use arrow::record_batch::RecordBatch; - use arrow_array::types::Int32Type; - use arrow_array::Array; - use arrow_array::ListArray; - use arrow_buffer::OffsetBuffer; - use datafusion_common::internal_err; - - // arrow::compute::sort can't sort nested ListArray directly, so we compare the scalar values pair-wise. 
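For context on the accumulator being removed here, a simplified sketch of its dedup idea: each array element is lifted into a `ScalarValue`, which is hashable, so a `HashSet` carries the DISTINCT state. The input values below are made up.

use std::collections::HashSet;
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use datafusion_common::{Result, ScalarValue};

fn distinct_scalars(array: &ArrayRef) -> Result<HashSet<ScalarValue>> {
    let mut values = HashSet::new();
    for i in 0..array.len() {
        // ScalarValue implements Eq + Hash, so the set performs the DISTINCT step.
        values.insert(ScalarValue::try_from_array(array, i)?);
    }
    Ok(values)
}

fn main() -> Result<()> {
    let input: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 2, 3, 1]));
    assert_eq!(distinct_scalars(&input)?.len(), 3);
    Ok(())
}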
- fn compare_list_contents( - expected: Vec, - actual: ScalarValue, - ) -> Result<()> { - let array = actual.to_array()?; - let list_array = array.as_list::(); - let inner_array = list_array.value(0); - let mut actual_scalars = vec![]; - for index in 0..inner_array.len() { - let sv = ScalarValue::try_from_array(&inner_array, index)?; - actual_scalars.push(sv); - } - - if actual_scalars.len() != expected.len() { - return internal_err!( - "Expected and actual list lengths differ: expected={}, actual={}", - expected.len(), - actual_scalars.len() - ); - } - - let mut seen = vec![false; expected.len()]; - for v in expected { - let mut found = false; - for (i, sv) in actual_scalars.iter().enumerate() { - if sv == &v { - seen[i] = true; - found = true; - break; - } - } - if !found { - return internal_err!( - "Expected value {:?} not found in actual values {:?}", - v, - actual_scalars - ); - } - } - - Ok(()) - } - - fn check_distinct_array_agg( - input: ArrayRef, - expected: Vec, - datatype: DataType, - ) -> Result<()> { - let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![input])?; - - let agg = Arc::new(DistinctArrayAgg::new( - col("a", &schema)?, - "bla".to_string(), - datatype, - )); - let actual = aggregate(&batch, agg)?; - compare_list_contents(expected, actual) - } - - fn check_merge_distinct_array_agg( - input1: ArrayRef, - input2: ArrayRef, - expected: Vec, - datatype: DataType, - ) -> Result<()> { - let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]); - let agg = Arc::new(DistinctArrayAgg::new( - col("a", &schema)?, - "bla".to_string(), - datatype, - )); - - let mut accum1 = agg.create_accumulator()?; - let mut accum2 = agg.create_accumulator()?; - - accum1.update_batch(&[input1])?; - accum2.update_batch(&[input2])?; - - let array = accum2.state()?[0].raw_data()?; - accum1.merge_batch(&[array])?; - - let actual = accum1.evaluate()?; - compare_list_contents(expected, actual) - } - - #[test] - fn distinct_array_agg_i32() -> Result<()> { - let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2])); - - let expected = vec![ - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(2)), - ScalarValue::Int32(Some(4)), - ScalarValue::Int32(Some(5)), - ScalarValue::Int32(Some(7)), - ]; - - check_distinct_array_agg(col, expected, DataType::Int32) - } - - #[test] - fn merge_distinct_array_agg_i32() -> Result<()> { - let col1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2])); - let col2: ArrayRef = Arc::new(Int32Array::from(vec![1, 3, 7, 8, 4])); - - let expected = vec![ - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(2)), - ScalarValue::Int32(Some(3)), - ScalarValue::Int32(Some(4)), - ScalarValue::Int32(Some(5)), - ScalarValue::Int32(Some(7)), - ScalarValue::Int32(Some(8)), - ]; - - check_merge_distinct_array_agg(col1, col2, expected, DataType::Int32) - } - - #[test] - fn distinct_array_agg_nested() -> Result<()> { - // [[1, 2, 3], [4, 5]] - let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), - Some(3), - ])]); - let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(4), - Some(5), - ])]); - let l1 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([2]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - // [[6], [7, 8]] - let a1 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(6)])]); - let a2 = 
ListArray::from_iter_primitive::(vec![Some(vec![ - Some(7), - Some(8), - ])]); - let l2 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([2]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - // [[9]] - let a1 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(9)])]); - let l3 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([1]), - Arc::new(a1), - None, - ); - - let l1 = ScalarValue::List(Arc::new(l1)); - let l2 = ScalarValue::List(Arc::new(l2)); - let l3 = ScalarValue::List(Arc::new(l3)); - - // Duplicate l1 and l3 in the input array and check that it is deduped in the output. - let array = ScalarValue::iter_to_array(vec![ - l1.clone(), - l2.clone(), - l3.clone(), - l3.clone(), - l1.clone(), - ]) - .unwrap(); - let expected = vec![l1, l2, l3]; - - check_distinct_array_agg( - array, - expected, - DataType::List(Arc::new(Field::new_list( - "item", - Field::new("item", DataType::Int32, true), - true, - ))), - ) - } - - #[test] - fn merge_distinct_array_agg_nested() -> Result<()> { - // [[1, 2], [3, 4]] - let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), - ])]); - let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(3), - Some(4), - ])]); - let l1 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([2]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - let a1 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(5)])]); - let l2 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([1]), - Arc::new(a1), - None, - ); - - // [[6, 7], [8]] - let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(6), - Some(7), - ])]); - let a2 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(8)])]); - let l3 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([2]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - let l1 = ScalarValue::List(Arc::new(l1)); - let l2 = ScalarValue::List(Arc::new(l2)); - let l3 = ScalarValue::List(Arc::new(l3)); - - // Duplicate l1 in the input array and check that it is deduped in the output. 
- let input1 = ScalarValue::iter_to_array(vec![l1.clone(), l2.clone()]).unwrap(); - let input2 = ScalarValue::iter_to_array(vec![l1.clone(), l3.clone()]).unwrap(); - - let expected = vec![l1, l2, l3]; - - check_merge_distinct_array_agg(input1, input2, expected, DataType::Int32) - } -} diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index ef21b3d0f788..bdc41ff0a9bc 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -30,7 +30,7 @@ use std::sync::Arc; use arrow::datatypes::Schema; -use datafusion_common::{not_impl_err, Result}; +use datafusion_common::Result; use datafusion_expr::AggregateFunction; use crate::expressions::{self}; @@ -42,7 +42,7 @@ pub fn create_aggregate_expr( fun: &AggregateFunction, distinct: bool, input_phy_exprs: &[Arc], - ordering_req: &[PhysicalSortExpr], + _ordering_req: &[PhysicalSortExpr], input_schema: &Schema, name: impl Into, _ignore_nulls: bool, @@ -54,36 +54,8 @@ pub fn create_aggregate_expr( .map(|e| e.data_type(input_schema)) .collect::>>()?; let data_type = input_phy_types[0].clone(); - let ordering_types = ordering_req - .iter() - .map(|e| e.expr.data_type(input_schema)) - .collect::>>()?; let input_phy_exprs = input_phy_exprs.to_vec(); Ok(match (fun, distinct) { - (AggregateFunction::ArrayAgg, false) => { - let expr = Arc::clone(&input_phy_exprs[0]); - - if ordering_req.is_empty() { - Arc::new(expressions::ArrayAgg::new(expr, name, data_type)) - } else { - Arc::new(expressions::OrderSensitiveArrayAgg::new( - expr, - name, - data_type, - ordering_types, - ordering_req.to_vec(), - )) - } - } - (AggregateFunction::ArrayAgg, true) => { - if !ordering_req.is_empty() { - return not_impl_err!( - "ARRAY_AGG(DISTINCT ORDER BY a ASC) order-sensitive aggregations are not available" - ); - } - let expr = Arc::clone(&input_phy_exprs[0]); - Arc::new(expressions::DistinctArrayAgg::new(expr, name, data_type)) - } (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( Arc::clone(&input_phy_exprs[0]), name, @@ -104,70 +76,9 @@ mod tests { use datafusion_common::plan_err; use datafusion_expr::{type_coercion, Signature}; - use crate::expressions::{try_cast, ArrayAgg, DistinctArrayAgg, Max, Min}; + use crate::expressions::{try_cast, Max, Min}; use super::*; - #[test] - fn test_approx_expr() -> Result<()> { - let funcs = vec![AggregateFunction::ArrayAgg]; - let data_types = vec![ - DataType::UInt32, - DataType::Int32, - DataType::Float32, - DataType::Float64, - DataType::Decimal128(10, 2), - DataType::Utf8, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::ArrayAgg { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new_list( - "c1", - Field::new("item", data_type.clone(), true), - true, - ), - result_agg_phy_exprs.field().unwrap() - ); - } - - let result_distinct = create_physical_agg_expr_for_test( - &fun, - true, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::ArrayAgg { - assert!(result_distinct.as_any().is::()); - assert_eq!("c1", 
result_distinct.name()); - assert_eq!( - Field::new_list( - "c1", - Field::new("item", data_type.clone(), true), - true, - ), - result_agg_phy_exprs.field().unwrap() - ); - } - } - } - Ok(()) - } #[test] fn test_min_max_expr() -> Result<()> { @@ -211,7 +122,6 @@ mod tests { result_agg_phy_exprs.field().unwrap() ); } - _ => {} }; } } diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs index 9856e1c989b3..592c130b69d8 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs @@ -69,7 +69,7 @@ impl AccumulatorState { } } - /// Returns the amount of memory taken by this structre and its accumulator + /// Returns the amount of memory taken by this structure and its accumulator fn size(&self) -> usize { self.accumulator.size() + std::mem::size_of_val(self) diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index 65bb9e478c3d..9987e97b38d3 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -296,7 +296,7 @@ macro_rules! typed_min_max_batch_string { }}; } -// Statically-typed version of min/max(array) -> ScalarValue for binay types. +// Statically-typed version of min/max(array) -> ScalarValue for binary types. macro_rules! typed_min_max_batch_binary { ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ let array = downcast_value!($VALUES, $ARRAYTYPE); diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index b9d803900f53..264c48513050 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -15,11 +15,6 @@ // specific language governing permissions and limitations // under the License. -pub use datafusion_physical_expr_common::aggregate::AggregateExpr; - -pub(crate) mod array_agg; -pub(crate) mod array_agg_distinct; -pub(crate) mod array_agg_ordered; #[macro_use] pub(crate) mod min_max; pub(crate) mod groups_accumulator; @@ -33,3 +28,5 @@ pub mod utils { get_sort_options, ordering_fields, DecimalAverager, Hashable, }; } + +pub use datafusion_physical_expr_common::aggregate::AggregateExpr; diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index e483f935b75c..ffa58e385322 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -67,7 +67,7 @@ impl ConstExpr { pub fn new(expr: Arc) -> Self { Self { expr, - // By default, assume constant expressions are not same accross partitions. + // By default, assume constant expressions are not same across partitions. 
across_partitions: false, } } diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index 83f94057f740..b9228282b081 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -30,7 +30,9 @@ mod properties; pub use class::{ConstExpr, EquivalenceClass, EquivalenceGroup}; pub use ordering::OrderingEquivalenceClass; pub use projection::ProjectionMapping; -pub use properties::{join_equivalence_properties, EquivalenceProperties}; +pub use properties::{ + calculate_union, join_equivalence_properties, EquivalenceProperties, +}; /// This function constructs a duplicate-free `LexOrderingReq` by filtering out /// duplicate entries that have same physical expression inside. For example, diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 8c327fbaf409..64c22064d4b7 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -21,7 +21,8 @@ use std::sync::Arc; use super::ordering::collapse_lex_ordering; use crate::equivalence::class::const_exprs_contains; use crate::equivalence::{ - collapse_lex_req, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, + collapse_lex_req, EquivalenceClass, EquivalenceGroup, OrderingEquivalenceClass, + ProjectionMapping, }; use crate::expressions::Literal; use crate::{ @@ -32,11 +33,12 @@ use crate::{ use arrow_schema::{SchemaRef, SortOptions}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{JoinSide, JoinType, Result}; +use datafusion_common::{plan_err, JoinSide, JoinType, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_physical_expr_common::expressions::column::Column; use datafusion_physical_expr_common::expressions::CastExpr; +use datafusion_physical_expr_common::physical_expr::with_new_schema; use datafusion_physical_expr_common::utils::ExprPropertiesNode; use indexmap::{IndexMap, IndexSet}; @@ -536,33 +538,6 @@ impl EquivalenceProperties { .then_some(if lhs.len() >= rhs.len() { lhs } else { rhs }) } - /// Calculates the "meet" of the given orderings (`lhs` and `rhs`). - /// The meet of a set of orderings is the finest ordering that is satisfied - /// by all the orderings in that set. For details, see: - /// - /// - /// - /// If there is no ordering that satisfies both `lhs` and `rhs`, returns - /// `None`. As an example, the meet of orderings `[a ASC]` and `[a ASC, b ASC]` - /// is `[a ASC]`. - pub fn get_meet_ordering( - &self, - lhs: LexOrderingRef, - rhs: LexOrderingRef, - ) -> Option { - let lhs = self.normalize_sort_exprs(lhs); - let rhs = self.normalize_sort_exprs(rhs); - let mut meet = vec![]; - for (lhs, rhs) in lhs.into_iter().zip(rhs.into_iter()) { - if lhs.eq(&rhs) { - meet.push(lhs); - } else { - break; - } - } - (!meet.is_empty()).then_some(meet) - } - /// we substitute the ordering according to input expression type, this is a simplified version /// In this case, we just substitute when the expression satisfy the following condition: /// I. 
just have one column and is a CAST expression @@ -1007,6 +982,74 @@ impl EquivalenceProperties { .map(|node| node.data) .unwrap_or(ExprProperties::new_unknown()) } + + /// Transforms this `EquivalenceProperties` into a new `EquivalenceProperties` + /// by mapping columns in the original schema to columns in the new schema + /// by index. + pub fn with_new_schema(self, schema: SchemaRef) -> Result { + // The new schema and the original schema is aligned when they have the + // same number of columns, and fields at the same index have the same + // type in both schemas. + let schemas_aligned = (self.schema.fields.len() == schema.fields.len()) + && self + .schema + .fields + .iter() + .zip(schema.fields.iter()) + .all(|(lhs, rhs)| lhs.data_type().eq(rhs.data_type())); + if !schemas_aligned { + // Rewriting equivalence properties in terms of new schema is not + // safe when schemas are not aligned: + return plan_err!( + "Cannot rewrite old_schema:{:?} with new schema: {:?}", + self.schema, + schema + ); + } + // Rewrite constants according to new schema: + let new_constants = self + .constants + .into_iter() + .map(|const_expr| { + let across_partitions = const_expr.across_partitions(); + let new_const_expr = with_new_schema(const_expr.owned_expr(), &schema)?; + Ok(ConstExpr::new(new_const_expr) + .with_across_partitions(across_partitions)) + }) + .collect::>>()?; + + // Rewrite orderings according to new schema: + let mut new_orderings = vec![]; + for ordering in self.oeq_class.orderings { + let new_ordering = ordering + .into_iter() + .map(|mut sort_expr| { + sort_expr.expr = with_new_schema(sort_expr.expr, &schema)?; + Ok(sort_expr) + }) + .collect::>()?; + new_orderings.push(new_ordering); + } + + // Rewrite equivalence classes according to the new schema: + let mut eq_classes = vec![]; + for eq_class in self.eq_group.classes { + let new_eq_exprs = eq_class + .into_vec() + .into_iter() + .map(|expr| with_new_schema(expr, &schema)) + .collect::>()?; + eq_classes.push(EquivalenceClass::new(new_eq_exprs)); + } + + // Construct the resulting equivalence properties: + let mut result = EquivalenceProperties::new(schema); + result.constants = new_constants; + result.add_new_orderings(new_orderings); + result.add_equivalence_group(EquivalenceGroup::new(eq_classes)); + + Ok(result) + } } /// Calculates the properties of a given [`ExprPropertiesNode`]. @@ -1484,6 +1527,84 @@ impl Hash for ExprWrapper { } } +/// Calculates the union (in the sense of `UnionExec`) `EquivalenceProperties` +/// of `lhs` and `rhs` according to the schema of `lhs`. +fn calculate_union_binary( + lhs: EquivalenceProperties, + mut rhs: EquivalenceProperties, +) -> Result { + // TODO: In some cases, we should be able to preserve some equivalence + // classes. Add support for such cases. + + // Harmonize the schema of the rhs with the schema of the lhs (which is the accumulator schema): + if !rhs.schema.eq(&lhs.schema) { + rhs = rhs.with_new_schema(Arc::clone(&lhs.schema))?; + } + + // First, calculate valid constants for the union. A quantity is constant + // after the union if it is constant in both sides. + let constants = lhs + .constants() + .iter() + .filter(|const_expr| const_exprs_contains(rhs.constants(), const_expr.expr())) + .map(|const_expr| { + // TODO: When both sides' constants are valid across partitions, + // the union's constant should also be valid if values are + // the same. However, we do not have the capability to + // check this yet. 
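A small self-contained sketch (not the patch code itself) of the alignment test `with_new_schema` performs before rewriting: the two schemas must have the same number of fields and the same data type at every index, while names and nullability may differ. The field names here are invented.

use arrow::datatypes::{DataType, Field, Schema};

fn schemas_aligned(lhs: &Schema, rhs: &Schema) -> bool {
    lhs.fields().len() == rhs.fields().len()
        && lhs
            .fields()
            .iter()
            .zip(rhs.fields().iter())
            .all(|(l, r)| l.data_type() == r.data_type())
}

fn main() {
    let lhs = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
    let rhs = Schema::new(vec![Field::new("a1", DataType::Int32, false)]);
    // Only the name and nullability differ, so rewriting by index is safe.
    assert!(schemas_aligned(&lhs, &rhs));
}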
+ ConstExpr::new(Arc::clone(const_expr.expr())).with_across_partitions(false) + }) + .collect(); + + // Next, calculate valid orderings for the union by searching for prefixes + // in both sides. + let mut orderings = vec![]; + for mut ordering in lhs.normalized_oeq_class().orderings { + // Progressively shorten the ordering to search for a satisfied prefix: + while !rhs.ordering_satisfy(&ordering) { + ordering.pop(); + } + // There is a non-trivial satisfied prefix, add it as a valid ordering: + if !ordering.is_empty() { + orderings.push(ordering); + } + } + for mut ordering in rhs.normalized_oeq_class().orderings { + // Progressively shorten the ordering to search for a satisfied prefix: + while !lhs.ordering_satisfy(&ordering) { + ordering.pop(); + } + // There is a non-trivial satisfied prefix, add it as a valid ordering: + if !ordering.is_empty() { + orderings.push(ordering); + } + } + let mut eq_properties = EquivalenceProperties::new(lhs.schema); + eq_properties.constants = constants; + eq_properties.add_new_orderings(orderings); + Ok(eq_properties) +} + +/// Calculates the union (in the sense of `UnionExec`) `EquivalenceProperties` +/// of the given `EquivalenceProperties` in `eqps` according to the given +/// output `schema` (which need not be the same with those of `lhs` and `rhs` +/// as details such as nullability may be different). +pub fn calculate_union( + eqps: Vec, + schema: SchemaRef, +) -> Result { + // TODO: In some cases, we should be able to preserve some equivalence + // classes. Add support for such cases. + let mut init = eqps[0].clone(); + // Harmonize the schema of the init with the schema of the union: + if !init.schema.eq(&schema) { + init = init.with_new_schema(schema)?; + } + eqps.into_iter() + .skip(1) + .try_fold(init, calculate_union_binary) +} + #[cfg(test)] mod tests { use std::ops::Not; @@ -2188,50 +2309,6 @@ mod tests { Ok(()) } - #[test] - fn test_get_meet_ordering() -> Result<()> { - let schema = create_test_schema()?; - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let eq_properties = EquivalenceProperties::new(schema); - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - let tests_cases = vec![ - // Get meet ordering between [a ASC] and [a ASC, b ASC] - // result should be [a ASC] - ( - vec![(col_a, option_asc)], - vec![(col_a, option_asc), (col_b, option_asc)], - Some(vec![(col_a, option_asc)]), - ), - // Get meet ordering between [a ASC] and [a DESC] - // result should be None. - (vec![(col_a, option_asc)], vec![(col_a, option_desc)], None), - // Get meet ordering between [a ASC, b ASC] and [a ASC, b DESC] - // result should be [a ASC]. 
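To make the prefix search in `calculate_union_binary` concrete, here is a minimal sketch with invented types: an ordering is a list of `(column index, ascending)` pairs that is progressively shortened until the other side satisfies it, so `[a ASC, b ASC]` combined with `[a ASC]` contributes `[a ASC]` to the union.

fn satisfied_prefix<'a>(
    ordering: &'a [(u32, bool)],
    satisfies: impl Fn(&[(u32, bool)]) -> bool,
) -> &'a [(u32, bool)] {
    let mut len = ordering.len();
    // Progressively shorten the candidate until the other side satisfies it.
    while len > 0 && !satisfies(&ordering[..len]) {
        len -= 1;
    }
    &ordering[..len]
}

fn main() {
    let lhs = [(0, true), (1, true)]; // [a ASC, b ASC]
    let rhs = [(0, true)]; // [a ASC]
    // Stand-in for `ordering_satisfy`: here satisfaction is a plain prefix check.
    let prefix = satisfied_prefix(&lhs, |candidate| rhs.starts_with(candidate));
    assert_eq!(prefix, &lhs[..1]); // the union keeps [a ASC]
}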
- ( - vec![(col_a, option_asc), (col_b, option_asc)], - vec![(col_a, option_asc), (col_b, option_desc)], - Some(vec![(col_a, option_asc)]), - ), - ]; - for (lhs, rhs, expected) in tests_cases { - let lhs = convert_to_sort_exprs(&lhs); - let rhs = convert_to_sort_exprs(&rhs); - let expected = expected.map(|expected| convert_to_sort_exprs(&expected)); - let finer = eq_properties.get_meet_ordering(&lhs, &rhs); - assert_eq!(finer, expected) - } - - Ok(()) - } - #[test] fn test_get_finer() -> Result<()> { let schema = create_test_schema()?; @@ -2525,4 +2602,422 @@ mod tests { Ok(()) } + + fn append_fields(schema: &SchemaRef, text: &str) -> SchemaRef { + Arc::new(Schema::new( + schema + .fields() + .iter() + .map(|field| { + Field::new( + // Annotate name with `text`: + format!("{}{}", field.name(), text), + field.data_type().clone(), + field.is_nullable(), + ) + }) + .collect::>(), + )) + } + + #[tokio::test] + async fn test_union_equivalence_properties_multi_children() -> Result<()> { + let schema = create_test_schema()?; + let schema2 = append_fields(&schema, "1"); + let schema3 = append_fields(&schema, "2"); + let test_cases = vec![ + // --------- TEST CASE 1 ---------- + ( + vec![ + // Children 1 + ( + // Orderings + vec![vec!["a", "b", "c"]], + Arc::clone(&schema), + ), + // Children 2 + ( + // Orderings + vec![vec!["a1", "b1", "c1"]], + Arc::clone(&schema2), + ), + // Children 3 + ( + // Orderings + vec![vec!["a2", "b2"]], + Arc::clone(&schema3), + ), + ], + // Expected + vec![vec!["a", "b"]], + ), + // --------- TEST CASE 2 ---------- + ( + vec![ + // Children 1 + ( + // Orderings + vec![vec!["a", "b", "c"]], + Arc::clone(&schema), + ), + // Children 2 + ( + // Orderings + vec![vec!["a1", "b1", "c1"]], + Arc::clone(&schema2), + ), + // Children 3 + ( + // Orderings + vec![vec!["a2", "b2", "c2"]], + Arc::clone(&schema3), + ), + ], + // Expected + vec![vec!["a", "b", "c"]], + ), + // --------- TEST CASE 3 ---------- + ( + vec![ + // Children 1 + ( + // Orderings + vec![vec!["a", "b"]], + Arc::clone(&schema), + ), + // Children 2 + ( + // Orderings + vec![vec!["a1", "b1", "c1"]], + Arc::clone(&schema2), + ), + // Children 3 + ( + // Orderings + vec![vec!["a2", "b2", "c2"]], + Arc::clone(&schema3), + ), + ], + // Expected + vec![vec!["a", "b"]], + ), + // --------- TEST CASE 4 ---------- + ( + vec![ + // Children 1 + ( + // Orderings + vec![vec!["a", "b"]], + Arc::clone(&schema), + ), + // Children 2 + ( + // Orderings + vec![vec!["a1", "b1"]], + Arc::clone(&schema2), + ), + // Children 3 + ( + // Orderings + vec![vec!["b2", "c2"]], + Arc::clone(&schema3), + ), + ], + // Expected + vec![], + ), + // --------- TEST CASE 5 ---------- + ( + vec![ + // Children 1 + ( + // Orderings + vec![vec!["a", "b"], vec!["c"]], + Arc::clone(&schema), + ), + // Children 2 + ( + // Orderings + vec![vec!["a1", "b1"], vec!["c1"]], + Arc::clone(&schema2), + ), + ], + // Expected + vec![vec!["a", "b"], vec!["c"]], + ), + ]; + for (children, expected) in test_cases { + let children_eqs = children + .iter() + .map(|(orderings, schema)| { + let orderings = orderings + .iter() + .map(|ordering| { + ordering + .iter() + .map(|name| PhysicalSortExpr { + expr: col(name, schema).unwrap(), + options: SortOptions::default(), + }) + .collect::>() + }) + .collect::>(); + EquivalenceProperties::new_with_orderings( + Arc::clone(schema), + &orderings, + ) + }) + .collect::>(); + let actual = calculate_union(children_eqs, Arc::clone(&schema))?; + + let expected_ordering = expected + .into_iter() + .map(|ordering| { + ordering + 
.into_iter() + .map(|name| PhysicalSortExpr { + expr: col(name, &schema).unwrap(), + options: SortOptions::default(), + }) + .collect::>() + }) + .collect::>(); + let expected = EquivalenceProperties::new_with_orderings( + Arc::clone(&schema), + &expected_ordering, + ); + assert_eq_properties_same( + &actual, + &expected, + format!("expected: {:?}, actual: {:?}", expected, actual), + ); + } + Ok(()) + } + + #[tokio::test] + async fn test_union_equivalence_properties_binary() -> Result<()> { + let schema = create_test_schema()?; + let schema2 = append_fields(&schema, "1"); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_a1 = &col("a1", &schema2)?; + let col_b1 = &col("b1", &schema2)?; + let options = SortOptions::default(); + let options_desc = !SortOptions::default(); + let test_cases = [ + //-----------TEST CASE 1----------// + ( + ( + // First child orderings + vec![ + // [a ASC] + (vec![(col_a, options)]), + ], + // First child constants + vec![col_b, col_c], + Arc::clone(&schema), + ), + ( + // Second child orderings + vec![ + // [b ASC] + (vec![(col_b, options)]), + ], + // Second child constants + vec![col_a, col_c], + Arc::clone(&schema), + ), + ( + // Union expected orderings + vec![ + // [a ASC] + vec![(col_a, options)], + // [b ASC] + vec![(col_b, options)], + ], + // Union + vec![col_c], + ), + ), + //-----------TEST CASE 2----------// + // Meet ordering between [a ASC], [a ASC, b ASC] should be [a ASC] + ( + ( + // First child orderings + vec![ + // [a ASC] + vec![(col_a, options)], + ], + // No constant + vec![], + Arc::clone(&schema), + ), + ( + // Second child orderings + vec![ + // [a ASC, b ASC] + vec![(col_a, options), (col_b, options)], + ], + // No constant + vec![], + Arc::clone(&schema), + ), + ( + // Union orderings + vec![ + // [a ASC] + vec![(col_a, options)], + ], + // No constant + vec![], + ), + ), + //-----------TEST CASE 3----------// + // Meet ordering between [a ASC], [a DESC] should be [] + ( + ( + // First child orderings + vec![ + // [a ASC] + vec![(col_a, options)], + ], + // No constant + vec![], + Arc::clone(&schema), + ), + ( + // Second child orderings + vec![ + // [a DESC] + vec![(col_a, options_desc)], + ], + // No constant + vec![], + Arc::clone(&schema), + ), + ( + // Union doesn't have any ordering + vec![], + // No constant + vec![], + ), + ), + //-----------TEST CASE 4----------// + // Meet ordering between [a ASC], [a1 ASC, b1 ASC] should be [a ASC] + // Where a, and a1 ath the same index for their corresponding schemas. 
+ ( + ( + // First child orderings + vec![ + // [a ASC] + vec![(col_a, options)], + ], + // No constant + vec![], + Arc::clone(&schema), + ), + ( + // Second child orderings + vec![ + // [a1 ASC, b1 ASC] + vec![(col_a1, options), (col_b1, options)], + ], + // No constant + vec![], + Arc::clone(&schema2), + ), + ( + // Union orderings + vec![ + // [a ASC] + vec![(col_a, options)], + ], + // No constant + vec![], + ), + ), + ]; + + for ( + test_idx, + ( + (first_child_orderings, first_child_constants, first_schema), + (second_child_orderings, second_child_constants, second_schema), + (union_orderings, union_constants), + ), + ) in test_cases.iter().enumerate() + { + let first_orderings = first_child_orderings + .iter() + .map(|ordering| convert_to_sort_exprs(ordering)) + .collect::>(); + let first_constants = first_child_constants + .iter() + .map(|expr| ConstExpr::new(Arc::clone(expr))) + .collect::>(); + let mut lhs = EquivalenceProperties::new(Arc::clone(first_schema)); + lhs = lhs.add_constants(first_constants); + lhs.add_new_orderings(first_orderings); + + let second_orderings = second_child_orderings + .iter() + .map(|ordering| convert_to_sort_exprs(ordering)) + .collect::>(); + let second_constants = second_child_constants + .iter() + .map(|expr| ConstExpr::new(Arc::clone(expr))) + .collect::>(); + let mut rhs = EquivalenceProperties::new(Arc::clone(second_schema)); + rhs = rhs.add_constants(second_constants); + rhs.add_new_orderings(second_orderings); + + let union_expected_orderings = union_orderings + .iter() + .map(|ordering| convert_to_sort_exprs(ordering)) + .collect::>(); + let union_constants = union_constants + .iter() + .map(|expr| ConstExpr::new(Arc::clone(expr))) + .collect::>(); + let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema)); + union_expected_eq = union_expected_eq.add_constants(union_constants); + union_expected_eq.add_new_orderings(union_expected_orderings); + + let actual_union_eq = calculate_union_binary(lhs, rhs)?; + let err_msg = format!( + "Error in test id: {:?}, test case: {:?}", + test_idx, test_cases[test_idx] + ); + assert_eq_properties_same(&actual_union_eq, &union_expected_eq, err_msg); + } + Ok(()) + } + + fn assert_eq_properties_same( + lhs: &EquivalenceProperties, + rhs: &EquivalenceProperties, + err_msg: String, + ) { + // Check whether constants are same + let lhs_constants = lhs.constants(); + let rhs_constants = rhs.constants(); + assert_eq!(lhs_constants.len(), rhs_constants.len(), "{}", err_msg); + for rhs_constant in rhs_constants { + assert!( + const_exprs_contains(lhs_constants, rhs_constant.expr()), + "{}", + err_msg + ); + } + + // Check whether orderings are same. 
+ let lhs_orderings = lhs.oeq_class(); + let rhs_orderings = &rhs.oeq_class.orderings; + assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{}", err_msg); + for rhs_ordering in rhs_orderings { + assert!(lhs_orderings.contains(rhs_ordering), "{}", err_msg); + } + } } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index cd73c5cb579c..b428d562bd1b 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -32,10 +32,38 @@ use datafusion_common::cast::as_boolean_array; use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::ColumnarValue; +use datafusion_physical_expr_common::expressions::column::Column; +use datafusion_physical_expr_common::expressions::Literal; use itertools::Itertools; type WhenThen = (Arc, Arc); +#[derive(Debug, Hash)] +enum EvalMethod { + /// CASE WHEN condition THEN result + /// [WHEN ...] + /// [ELSE result] + /// END + NoExpression, + /// CASE expression + /// WHEN value THEN result + /// [WHEN ...] + /// [ELSE result] + /// END + WithExpression, + /// This is a specialization for a specific use case where we can take a fast path + /// for expressions that are infallible and can be cheaply computed for the entire + /// record batch rather than just for the rows where the predicate is true. + /// + /// CASE WHEN condition THEN column [ELSE NULL] END + InfallibleExprOrNull, + /// This is a specialization for a specific use case where we can take a fast path + /// if there is just one when/then pair and both the `then` and `else` expressions + /// are literal values + /// CASE WHEN condition THEN literal ELSE literal END + ScalarOrScalar, +} + /// The CASE expression is similar to a series of nested if/else and there are two forms that /// can be used. The first form consists of a series of boolean "when" expressions with /// corresponding "then" expressions, and an optional "else" expression. @@ -61,6 +89,8 @@ pub struct CaseExpr { when_then_expr: Vec, /// Optional "else" expression else_expr: Option>, + /// Evaluation method to use + eval_method: EvalMethod, } impl std::fmt::Display for CaseExpr { @@ -79,6 +109,15 @@ impl std::fmt::Display for CaseExpr { } } +/// This is a specialization for a specific use case where we can take a fast path +/// for expressions that are infallible and can be cheaply computed for the entire +/// record batch rather than just for the rows where the predicate is true. 
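As a rough standalone illustration (not the patch itself) of the two arrow kernels the `InfallibleExprOrNull` path leans on: the WHEN predicate is inverted with `not` and the THEN column is masked with `nullif`, so the whole batch is computed at once instead of row by row. The input values are made up.

use arrow::array::{BooleanArray, Int32Array};
use arrow::compute::{not, nullif};
use arrow::error::ArrowError;

fn main() -> Result<(), ArrowError> {
    let then_column = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
    let predicate = BooleanArray::from(vec![true, false, true, true]);
    // Rows where the predicate evaluated to false must become NULL in the output.
    let mask = not(&predicate)?;
    let result = nullif(&then_column, &mask)?;
    println!("{:?}", result); // 10, null, null, 40
    Ok(())
}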
For now, +/// this is limited to use with Column expressions but could potentially be used for other +/// expressions in the future +fn is_cheap_and_infallible(expr: &Arc) -> bool { + expr.as_any().is::() +} + impl CaseExpr { /// Create a new CASE WHEN expression pub fn try_new( @@ -86,13 +125,41 @@ impl CaseExpr { when_then_expr: Vec, else_expr: Option>, ) -> Result { + // normalize null literals to None in the else_expr (this already happens + // during SQL planning, but not necessarily for other use cases) + let else_expr = match &else_expr { + Some(e) => match e.as_any().downcast_ref::() { + Some(lit) if lit.value().is_null() => None, + _ => else_expr, + }, + _ => else_expr, + }; + if when_then_expr.is_empty() { exec_err!("There must be at least one WHEN clause") } else { + let eval_method = if expr.is_some() { + EvalMethod::WithExpression + } else if when_then_expr.len() == 1 + && is_cheap_and_infallible(&(when_then_expr[0].1)) + && else_expr.is_none() + { + EvalMethod::InfallibleExprOrNull + } else if when_then_expr.len() == 1 + && when_then_expr[0].1.as_any().is::() + && else_expr.is_some() + && else_expr.as_ref().unwrap().as_any().is::() + { + EvalMethod::ScalarOrScalar + } else { + EvalMethod::NoExpression + }; + Ok(Self { expr, when_then_expr, else_expr, + eval_method, }) } } @@ -256,6 +323,70 @@ impl CaseExpr { Ok(ColumnarValue::Array(current_value)) } + + /// This function evaluates the specialized case of: + /// + /// CASE WHEN condition THEN column + /// [ELSE NULL] + /// END + /// + /// Note that this function is only safe to use for "then" expressions + /// that are infallible because the expression will be evaluated for all + /// rows in the input batch. + fn case_column_or_null(&self, batch: &RecordBatch) -> Result { + let when_expr = &self.when_then_expr[0].0; + let then_expr = &self.when_then_expr[0].1; + if let ColumnarValue::Array(bit_mask) = when_expr.evaluate(batch)? { + let bit_mask = bit_mask + .as_any() + .downcast_ref::() + .expect("predicate should evaluate to a boolean array"); + // invert the bitmask + let bit_mask = not(bit_mask)?; + match then_expr.evaluate(batch)? 
{ + ColumnarValue::Array(array) => { + Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?)) + } + ColumnarValue::Scalar(_) => { + internal_err!("expression did not evaluate to an array") + } + } + } else { + internal_err!("predicate did not evaluate to an array") + } + } + + fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result { + let return_type = self.data_type(&batch.schema())?; + + // evaluate when expression + let when_value = self.when_then_expr[0].0.evaluate(batch)?; + let when_value = when_value.into_array(batch.num_rows())?; + let when_value = as_boolean_array(&when_value).map_err(|e| { + DataFusionError::Context( + "WHEN expression did not return a BooleanArray".to_string(), + Box::new(e), + ) + })?; + + // Treat 'NULL' as false value + let when_value = match when_value.null_count() { + 0 => Cow::Borrowed(when_value), + _ => Cow::Owned(prep_null_mask_filter(when_value)), + }; + + // evaluate then_value + let then_value = self.when_then_expr[0].1.evaluate(batch)?; + let then_value = Scalar::new(then_value.into_array(1)?); + + // keep `else_expr`'s data type and return type consistent + let e = self.else_expr.as_ref().unwrap(); + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) + .unwrap_or_else(|_| Arc::clone(e)); + let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?); + + Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?)) + } } impl PhysicalExpr for CaseExpr { @@ -303,14 +434,22 @@ impl PhysicalExpr for CaseExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - if self.expr.is_some() { - // this use case evaluates "expr" and then compares the values with the "when" - // values - self.case_when_with_expr(batch) - } else { - // The "when" conditions all evaluate to boolean in this use case and can be - // arbitrary expressions - self.case_when_no_expr(batch) + match self.eval_method { + EvalMethod::WithExpression => { + // this use case evaluates "expr" and then compares the values with the "when" + // values + self.case_when_with_expr(batch) + } + EvalMethod::NoExpression => { + // The "when" conditions all evaluate to boolean in this use case and can be + // arbitrary expressions + self.case_when_no_expr(batch) + } + EvalMethod::InfallibleExprOrNull => { + // Specialization for CASE WHEN expr THEN column [ELSE NULL] END + self.case_column_or_null(batch) + } + EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch), } } @@ -409,7 +548,7 @@ pub fn case( #[cfg(test)] mod tests { use super::*; - use crate::expressions::{binary, cast, col, lit}; + use crate::expressions::{binary, cast, col, lit, BinaryExpr}; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; @@ -419,6 +558,7 @@ mod tests { use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_expr::type_coercion::binary::comparison_coercion; use datafusion_expr::Operator; + use datafusion_physical_expr_common::expressions::Literal; #[test] fn case_with_expr() -> Result<()> { @@ -931,7 +1071,7 @@ mod tests { } #[test] - fn case_tranform() -> Result<()> { + fn case_transform() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); let when1 = lit("foo"); @@ -998,6 +1138,53 @@ mod tests { Ok(()) } + #[test] + fn test_column_or_null_specialization() -> Result<()> { + // create input data + let mut c1 = Int32Builder::new(); + let mut c2 = StringBuilder::new(); + for i in 0..1000 { + c1.append_value(i); + if i % 7 == 0 { + c2.append_null(); + } else { + c2.append_value(&format!("string 
{i}")); + } + } + let c1 = Arc::new(c1.finish()); + let c2 = Arc::new(c2.finish()); + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Utf8, true), + ]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap(); + + // CaseWhenExprOrNull should produce same results as CaseExpr + let predicate = Arc::new(BinaryExpr::new( + make_col("c1", 0), + Operator::LtEq, + make_lit_i32(250), + )); + let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?; + assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull)); + match expr.evaluate(&batch)? { + ColumnarValue::Array(array) => { + assert_eq!(1000, array.len()); + assert_eq!(785, array.null_count()); + } + _ => unreachable!(), + } + Ok(()) + } + + fn make_col(name: &str, index: usize) -> Arc { + Arc::new(Column::new(name, index)) + } + + fn make_lit_i32(n: i32) -> Arc { + Arc::new(Literal::new(ScalarValue::Int32(Some(n)))) + } + fn generate_case_when_with_type_coercion( expr: Option>, when_thens: Vec, diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 7d8f12091f46..7cbe4e796844 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -20,7 +20,6 @@ #[macro_use] mod binary; mod case; -mod column; mod in_list; mod is_not_null; mod is_null; @@ -29,14 +28,12 @@ mod negative; mod no_op; mod not; mod try_cast; +mod unknown_column; /// Module with some convenient methods used in expression building pub mod helpers { pub use crate::aggregate::min_max::{max, min}; } -pub use crate::aggregate::array_agg::ArrayAgg; -pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg; -pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg; pub use crate::aggregate::build_in::create_aggregate_expr; pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; pub use crate::aggregate::stats::StatsType; @@ -50,7 +47,6 @@ pub use crate::PhysicalSortExpr; pub use binary::{binary, BinaryExpr}; pub use case::{case, CaseExpr}; -pub use column::UnKnownColumn; pub use datafusion_expr::utils::format_state_name; pub use datafusion_physical_expr_common::expressions::column::{col, Column}; pub use datafusion_physical_expr_common::expressions::literal::{lit, Literal}; @@ -63,160 +59,4 @@ pub use negative::{negative, NegativeExpr}; pub use no_op::NoOp; pub use not::{not, NotExpr}; pub use try_cast::{try_cast, TryCastExpr}; - -#[cfg(test)] -pub(crate) mod tests { - use std::sync::Arc; - - use crate::AggregateExpr; - - use arrow::record_batch::RecordBatch; - use datafusion_common::{Result, ScalarValue}; - - /// macro to perform an aggregation using [`datafusion_expr::Accumulator`] and verify the - /// result. - #[macro_export] - macro_rules! 
generic_test_op { - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { - generic_test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.data_type()) - }; - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ - let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); - - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; - - let agg = Arc::new(<$OP>::new( - col("a", &schema)?, - "bla".to_string(), - $EXPECTED_DATATYPE, - )); - let actual = aggregate(&batch, agg)?; - let expected = ScalarValue::from($EXPECTED); - - assert_eq!(expected, actual); - - Ok(()) as Result<(), ::datafusion_common::DataFusionError> - }}; - } - - /// Same as [`generic_test_op`] but with support for providing a 4th argument, usually - /// a boolean to indicate if using the distinct version of the op. - #[macro_export] - macro_rules! generic_test_distinct_op { - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $DISTINCT:expr, $EXPECTED:expr) => { - generic_test_distinct_op!( - $ARRAY, - $DATATYPE, - $OP, - $DISTINCT, - $EXPECTED, - $EXPECTED.data_type() - ) - }; - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $DISTINCT:expr, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ - let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); - - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; - - let agg = Arc::new(<$OP>::new( - col("a", &schema)?, - "bla".to_string(), - $EXPECTED_DATATYPE, - $DISTINCT, - )); - let actual = aggregate(&batch, agg)?; - let expected = ScalarValue::from($EXPECTED); - - assert_eq!(expected, actual); - - Ok(()) as Result<(), ::datafusion_common::DataFusionError> - }}; - } - - /// macro to perform an aggregation using [`crate::GroupsAccumulator`] and verify the result. - /// - /// The difference between this and the above `generic_test_op` is that the former checks - /// the old slow-path [`datafusion_expr::Accumulator`] implementation, while this checks - /// the new [`crate::GroupsAccumulator`] implementation. - #[macro_export] - macro_rules! generic_test_op_new { - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { - generic_test_op_new!( - $ARRAY, - $DATATYPE, - $OP, - $EXPECTED, - $EXPECTED.data_type().clone() - ) - }; - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ - let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); - - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; - - let agg = Arc::new(<$OP>::new( - col("a", &schema)?, - "bla".to_string(), - $EXPECTED_DATATYPE, - )); - let actual = aggregate_new(&batch, agg)?; - assert_eq!($EXPECTED, &actual); - - Ok(()) as Result<(), ::datafusion_common::DataFusionError> - }}; - } - - /// macro to perform an aggregation with two inputs and verify the result. - #[macro_export] - macro_rules! 
generic_test_op2 { - ($ARRAY1:expr, $ARRAY2:expr, $DATATYPE1:expr, $DATATYPE2:expr, $OP:ident, $EXPECTED:expr) => { - generic_test_op2!( - $ARRAY1, - $ARRAY2, - $DATATYPE1, - $DATATYPE2, - $OP, - $EXPECTED, - $EXPECTED.data_type() - ) - }; - ($ARRAY1:expr, $ARRAY2:expr, $DATATYPE1:expr, $DATATYPE2:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ - let schema = Schema::new(vec![ - Field::new("a", $DATATYPE1, true), - Field::new("b", $DATATYPE2, true), - ]); - let batch = - RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY1, $ARRAY2])?; - - let agg = Arc::new(<$OP>::new( - col("a", &schema)?, - col("b", &schema)?, - "bla".to_string(), - $EXPECTED_DATATYPE, - )); - let actual = aggregate(&batch, agg)?; - let expected = ScalarValue::from($EXPECTED); - - assert_eq!(expected, actual); - - Ok(()) - }}; - } - - pub fn aggregate( - batch: &RecordBatch, - agg: Arc, - ) -> Result { - let mut accum = agg.create_accumulator()?; - let expr = agg.expressions(); - let values = expr - .iter() - .map(|e| { - e.evaluate(batch) - .and_then(|v| v.into_array(batch.num_rows())) - }) - .collect::>>()?; - accum.update_batch(&values)?; - accum.evaluate() - } -} +pub use unknown_column::UnKnownColumn; diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index 3549a3df83bb..43b6c993d2b2 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -31,7 +31,7 @@ use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, Result, ScalarValue}; use datafusion_expr::ColumnarValue; -/// TRY_CAST expression casts an expression to a specific data type and retuns NULL on invalid cast +/// TRY_CAST expression casts an expression to a specific data type and returns NULL on invalid cast #[derive(Debug, Hash)] pub struct TryCastExpr { /// The expression to cast diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/unknown_column.rs similarity index 51% rename from datafusion/physical-expr/src/expressions/column.rs rename to datafusion/physical-expr/src/expressions/unknown_column.rs index f6525c7c0462..cb7221e7fa15 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/unknown_column.rs @@ -15,13 +15,12 @@ // specific language governing permissions and limitations // under the License. -//! Column expression +//! 
UnKnownColumn expression use std::any::Any; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::{ @@ -67,7 +66,7 @@ impl PhysicalExpr for UnKnownColumn { Ok(DataType::Null) } - /// Decide whehter this expression is nullable, given the schema of the input + /// Decide whether this expression is nullable, given the schema of the input fn nullable(&self, _input_schema: &Schema) -> Result { Ok(true) } @@ -95,56 +94,9 @@ impl PhysicalExpr for UnKnownColumn { } impl PartialEq for UnKnownColumn { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self == x) - .unwrap_or(false) - } -} - -#[cfg(test)] -mod test { - use crate::expressions::Column; - use crate::PhysicalExpr; - - use arrow::array::StringArray; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use datafusion_common::Result; - - use std::sync::Arc; - - #[test] - fn out_of_bounds_data_type() { - let schema = Schema::new(vec![Field::new("foo", DataType::Utf8, true)]); - let col = Column::new("id", 9); - let error = col.data_type(&schema).expect_err("error").strip_backtrace(); - assert!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ - but input schema only has 1 columns: [\"foo\"].\nThis was likely caused by a bug in \ - DataFusion's code and we would welcome that you file an bug report in our issue tracker".starts_with(&error)) - } - - #[test] - fn out_of_bounds_nullable() { - let schema = Schema::new(vec![Field::new("foo", DataType::Utf8, true)]); - let col = Column::new("id", 9); - let error = col.nullable(&schema).expect_err("error").strip_backtrace(); - assert!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ - but input schema only has 1 columns: [\"foo\"].\nThis was likely caused by a bug in \ - DataFusion's code and we would welcome that you file an bug report in our issue tracker".starts_with(&error)) - } - - #[test] - fn out_of_bounds_evaluate() -> Result<()> { - let schema = Schema::new(vec![Field::new("foo", DataType::Utf8, true)]); - let data: StringArray = vec!["data"].into(); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?; - let col = Column::new("id", 9); - let error = col.evaluate(&batch).expect_err("error").strip_backtrace(); - assert!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ - but input schema only has 1 columns: [\"foo\"].\nThis was likely caused by a bug in \ - DataFusion's code and we would welcome that you file an bug report in our issue tracker".starts_with(&error)); - Ok(()) + fn eq(&self, _other: &dyn Any) -> bool { + // UnknownColumn is not a valid expression, so it should not be equal to any other expression. 
+ // See https://github.com/apache/datafusion/pull/11536 + false } } diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 4f83ae01959b..2e78119eba46 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -48,7 +48,7 @@ pub use analysis::{analyze, AnalysisContext, ExprBoundaries}; pub use datafusion_physical_expr_common::aggregate::{ AggregateExpr, AggregatePhysicalExpressions, }; -pub use equivalence::{ConstExpr, EquivalenceProperties}; +pub use equivalence::{calculate_union, ConstExpr, EquivalenceProperties}; pub use partitioning::{Distribution, Partitioning}; pub use physical_expr::{ physical_exprs_bag_equal, physical_exprs_contains, physical_exprs_equal, diff --git a/datafusion/physical-expr/src/partitioning.rs b/datafusion/physical-expr/src/partitioning.rs index 821b2c9fe17a..6472dd47489c 100644 --- a/datafusion/physical-expr/src/partitioning.rs +++ b/datafusion/physical-expr/src/partitioning.rs @@ -84,9 +84,9 @@ use crate::{physical_exprs_equal, EquivalenceProperties, PhysicalExpr}; /// └──────────┐│┌──────────┘ /// │││ /// │││ -/// RepartitionExec with one input -/// that has 3 partitions, but 3 (async) streams, that internally -/// itself has only 1 output partition pull from the same input stream +/// RepartitionExec with 1 input +/// partition and 3 output partitions 3 (async) streams, that internally +/// pull from the same input stream /// ... /// ``` /// diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index 42e5e6fcf3ac..993ff5610063 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -283,7 +283,7 @@ impl<'a> GuaranteeBuilder<'a> { ) } - /// Aggregates a new single column, multi literal term to ths builder + /// Aggregates a new single column, multi literal term to this builder /// combining with previously known guarantees if possible. /// /// # Examples diff --git a/datafusion/physical-optimizer/Cargo.toml b/datafusion/physical-optimizer/Cargo.toml new file mode 100644 index 000000000000..125ea6acc77f --- /dev/null +++ b/datafusion/physical-optimizer/Cargo.toml @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[package] +name = "datafusion-physical-optimizer" +description = "DataFusion Physical Optimizer" +keywords = ["datafusion", "query", "optimizer"] +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } + +[lints] +workspace = true + +[dependencies] +datafusion-common = { workspace = true, default-features = true } +datafusion-execution = { workspace = true } +datafusion-physical-expr = { workspace = true } +datafusion-physical-plan = { workspace = true } diff --git a/datafusion/physical-optimizer/README.md b/datafusion/physical-optimizer/README.md new file mode 100644 index 000000000000..eb361d3f6779 --- /dev/null +++ b/datafusion/physical-optimizer/README.md @@ -0,0 +1,25 @@ + + +# DataFusion Physical Optimizer + +DataFusion is an extensible query execution framework, written in Rust, +that uses Apache Arrow as its in-memory format. + +This crate contains the physical optimizer for DataFusion. diff --git a/datafusion/physical-optimizer/src/lib.rs b/datafusion/physical-optimizer/src/lib.rs new file mode 100644 index 000000000000..6b9df7cad5c8 --- /dev/null +++ b/datafusion/physical-optimizer/src/lib.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] + +mod optimizer; +pub mod output_requirements; + +pub use optimizer::PhysicalOptimizerRule; diff --git a/datafusion/physical-optimizer/src/optimizer.rs b/datafusion/physical-optimizer/src/optimizer.rs new file mode 100644 index 000000000000..885dc4a95b8c --- /dev/null +++ b/datafusion/physical-optimizer/src/optimizer.rs @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Physical optimizer traits + +use datafusion_common::config::ConfigOptions; +use datafusion_common::Result; +use datafusion_physical_plan::ExecutionPlan; +use std::sync::Arc; + +/// `PhysicalOptimizerRule` transforms one ['ExecutionPlan'] into another which +/// computes the same results, but in a potentially more efficient way. +/// +/// Use [`SessionState::add_physical_optimizer_rule`] to register additional +/// `PhysicalOptimizerRule`s. +/// +/// [`SessionState::add_physical_optimizer_rule`]: https://docs.rs/datafusion/latest/datafusion/execution/session_state/struct.SessionState.html#method.add_physical_optimizer_rule +pub trait PhysicalOptimizerRule { + /// Rewrite `plan` to an optimized form + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result>; + + /// A human readable name for this optimizer rule + fn name(&self) -> &str; + + /// A flag to indicate whether the physical planner should valid the rule will not + /// change the schema of the plan after the rewriting. + /// Some of the optimization rules might change the nullable properties of the schema + /// and should disable the schema check. + fn schema_check(&self) -> bool; +} diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/physical-optimizer/src/output_requirements.rs similarity index 94% rename from datafusion/core/src/physical_optimizer/output_requirements.rs rename to datafusion/physical-optimizer/src/output_requirements.rs index 671bb437d5fa..f971d8f1f0aa 100644 --- a/datafusion/core/src/physical_optimizer/output_requirements.rs +++ b/datafusion/physical-optimizer/src/output_requirements.rs @@ -24,9 +24,11 @@ use std::sync::Arc; -use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::sorts::sort::SortExec; -use crate::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; +use datafusion_execution::TaskContext; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, SendableRecordBatchStream, +}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; @@ -35,6 +37,8 @@ use datafusion_physical_expr::{Distribution, LexRequirement, PhysicalSortRequire use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::{ExecutionPlanProperties, PlanProperties}; +use crate::PhysicalOptimizerRule; + /// This rule either adds or removes [`OutputRequirements`]s to/from the physical /// plan according to its `mode` attribute, which is set by the constructors /// `new_add_mode` and `new_remove_mode`. With this rule, we can keep track of @@ -86,7 +90,7 @@ enum RuleMode { /// /// See [`OutputRequirements`] for more details #[derive(Debug)] -pub(crate) struct OutputRequirementExec { +pub struct OutputRequirementExec { input: Arc, order_requirement: Option, dist_requirement: Distribution, @@ -94,7 +98,7 @@ pub(crate) struct OutputRequirementExec { } impl OutputRequirementExec { - pub(crate) fn new( + pub fn new( input: Arc, requirements: Option, dist_requirement: Distribution, @@ -108,8 +112,8 @@ impl OutputRequirementExec { } } - pub(crate) fn input(&self) -> Arc { - self.input.clone() + pub fn input(&self) -> Arc { + Arc::clone(&self.input) } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
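// Aside: with `PhysicalOptimizerRule` now exported from the standalone
// `datafusion-physical-optimizer` crate, custom rules can be written against it
// directly. A minimal no-op rule matching the trait declared above; the struct
// name and rule name are illustrative.
use std::sync::Arc;

use datafusion_common::config::ConfigOptions;
use datafusion_common::Result;
use datafusion_physical_optimizer::PhysicalOptimizerRule;
use datafusion_physical_plan::ExecutionPlan;

/// A rule that returns its input unchanged.
#[derive(Debug, Default)]
struct NoOpRule;

impl PhysicalOptimizerRule for NoOpRule {
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // A real rule would walk and rewrite `plan` here.
        Ok(plan)
    }

    fn name(&self) -> &str {
        "no_op_rule"
    }

    fn schema_check(&self) -> bool {
        // This rule never changes the plan schema, so keep the planner's check enabled.
        true
    }
}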
@@ -179,8 +183,8 @@ impl ExecutionPlan for OutputRequirementExec { fn execute( &self, _partition: usize, - _context: Arc, - ) -> Result { + _context: Arc, + ) -> Result { unreachable!(); } @@ -275,7 +279,7 @@ fn require_top_ordering_helper( // When an operator requires an ordering, any `SortExec` below can not // be responsible for (i.e. the originator of) the global ordering. let (new_child, is_changed) = - require_top_ordering_helper(children.swap_remove(0).clone())?; + require_top_ordering_helper(Arc::clone(children.swap_remove(0)))?; Ok((plan.with_new_children(vec![new_child])?, is_changed)) } else { // Stop searching, there is no global ordering desired for the query. diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 96a12d7b62da..8c2a4ba5c497 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -190,7 +190,7 @@ impl GroupValues for GroupValuesRows { let groups_rows = group_values.iter().take(n); let output = self.row_converter.convert_rows(groups_rows)?; // Clear out first n group keys by copying them to a new Rows. - // TODO file some ticket in arrow-rs to make this more efficent? + // TODO file some ticket in arrow-rs to make this more efficient? let mut new_group_values = self.row_converter.empty_rows(0, 0); for row in group_values.iter().skip(n) { new_group_values.push(row); diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 5f780f1ff801..e7cd5cb2725b 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -75,12 +75,12 @@ pub enum AggregateMode { /// Applies the entire logical aggregation operation in a single operator, /// as opposed to Partial / Final modes which apply the logical aggregation using /// two operators. - /// This mode requires tha the input is a single partition (like Final) + /// This mode requires that the input is a single partition (like Final) Single, /// Applies the entire logical aggregation operation in a single operator, /// as opposed to Partial / Final modes which apply the logical aggregation using /// two operators. - /// This mode requires tha the input is partitioned by group key (like FinalPartitioned) + /// This mode requires that the input is partitioned by group key (like FinalPartitioned) SinglePartitioned, } @@ -733,7 +733,7 @@ impl ExecutionPlan for AggregateExec { // - once expressions will be able to compute their own stats, use it here // - case where we group by on a column for which with have the `distinct` stat // TODO stats: aggr expression: - // - aggregations somtimes also preserve invariants such as min, max... + // - aggregations sometimes also preserve invariants such as min, max... 
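// Aside: several hunks in this patch replace `foo.clone()` with `Arc::clone(&foo)`,
// the form required by the `#![deny(clippy::clone_on_ref_ptr)]` attribute added to
// the new crate. A small sketch of the distinction; the names are illustrative.
use std::sync::Arc;

fn explicit_arc_clone() {
    let shared: Arc<String> = Arc::new("plan".to_string());

    // `shared.clone()` compiles, but hides whether the Arc is being bumped or the
    // underlying value deep-copied. `Arc::clone(&shared)` makes the cheap
    // reference-count increment explicit.
    let also_shared = Arc::clone(&shared);

    assert_eq!(Arc::strong_count(&shared), 2);
    drop(also_shared);
}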
let column_statistics = Statistics::unknown_column(&self.schema()); match self.mode { AggregateMode::Final | AggregateMode::FinalPartitioned @@ -1194,22 +1194,25 @@ mod tests { use arrow::datatypes::DataType; use arrow_array::{Float32Array, Int32Array}; use datafusion_common::{ - assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError, - ScalarValue, + assert_batches_eq, assert_batches_sorted_eq, internal_err, DFSchema, DFSchemaRef, + DataFusionError, ScalarValue, }; use datafusion_execution::config::SessionConfig; use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::expr::Sort; + use datafusion_functions_aggregate::array_agg::array_agg_udaf; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::first_last::{FirstValue, LastValue}; use datafusion_functions_aggregate::median::median_udaf; - use datafusion_physical_expr::expressions::{lit, OrderSensitiveArrayAgg}; + use datafusion_physical_expr::expressions::lit; use datafusion_physical_expr::PhysicalSortExpr; use crate::common::collect; - use datafusion_physical_expr_common::aggregate::create_aggregate_expr; + use datafusion_physical_expr_common::aggregate::{ + create_aggregate_expr, create_aggregate_expr_with_dfschema, + }; use datafusion_physical_expr_common::expressions::Literal; use futures::{FutureExt, Stream}; @@ -1258,19 +1261,22 @@ mod tests { } /// Generates some mock data for aggregate tests. - fn some_data_v2() -> (Arc, Vec) { + fn some_data_v2() -> (Arc, DFSchemaRef, Vec) { // Define a schema: let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::UInt32, false), Field::new("b", DataType::Float64, false), ])); + let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap(); + // Generate data so that first and last value results are at 2nd and // 3rd partitions. With this construction, we guarantee we don't receive // the expected result by accident, but merging actually works properly; // i.e. it doesn't depend on the data insertion order. 
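// Aside: the tests in this file now build a `DFSchema` next to the Arrow `Schema`
// (see `some_data_v2` above) because `create_aggregate_expr_with_dfschema` takes the
// qualified variant. A minimal sketch of that conversion; the function name is
// illustrative.
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{DFSchema, Result};

fn to_df_schema() -> Result<DFSchema> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::UInt32, false),
        Field::new("b", DataType::Float64, false),
    ]));

    // Same conversion the tests use: an unqualified DFSchema wrapping the Arrow schema.
    let df_schema = DFSchema::try_from(Arc::clone(&schema))?;
    assert_eq!(df_schema.field_names(), vec!["a", "b"]);
    Ok(df_schema)
}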
( Arc::clone(&schema), + Arc::new(df_schema), vec![ RecordBatch::try_new( Arc::clone(&schema), @@ -1355,6 +1361,7 @@ mod tests { "COUNT(1)", false, false, + false, )?]; let task_ctx = if spill { @@ -1504,6 +1511,7 @@ mod tests { "AVG(b)", false, false, + false, )?]; let task_ctx = if spill { @@ -1808,6 +1816,7 @@ mod tests { "MEDIAN(a)", false, false, + false, ) } @@ -1844,6 +1853,7 @@ mod tests { "AVG(b)", false, false, + false, )?]; for (version, groups, aggregates) in [ @@ -1908,6 +1918,7 @@ mod tests { "AVG(a)", false, false, + false, )?]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); @@ -1952,6 +1963,7 @@ mod tests { "AVG(b)", false, false, + false, )?]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); @@ -1996,12 +2008,11 @@ mod tests { // FIRST_VALUE(b ORDER BY b ) fn test_first_value_agg_expr( schema: &Schema, + dfschema: &DFSchema, sort_options: SortOptions, ) -> Result> { let sort_exprs = vec![datafusion_expr::Expr::Sort(Sort { - expr: Box::new(datafusion_expr::Expr::Column( - datafusion_common::Column::new(Some("table1"), "b"), - )), + expr: Box::new(datafusion_expr::col("b")), asc: !sort_options.descending, nulls_first: sort_options.nulls_first, })]; @@ -2012,28 +2023,28 @@ mod tests { let args = vec![col("b", schema)?]; let logical_args = vec![datafusion_expr::col("b")]; let func = datafusion_expr::AggregateUDF::new_from_impl(FirstValue::new()); - datafusion_physical_expr_common::aggregate::create_aggregate_expr( + datafusion_physical_expr_common::aggregate::create_aggregate_expr_with_dfschema( &func, &args, &logical_args, &sort_exprs, &ordering_req, - schema, + dfschema, "FIRST_VALUE(b)", false, false, + false, ) } // LAST_VALUE(b ORDER BY b ) fn test_last_value_agg_expr( schema: &Schema, + dfschema: &DFSchema, sort_options: SortOptions, ) -> Result> { let sort_exprs = vec![datafusion_expr::Expr::Sort(Sort { - expr: Box::new(datafusion_expr::Expr::Column( - datafusion_common::Column::new(Some("table1"), "b"), - )), + expr: Box::new(datafusion_expr::col("b")), asc: !sort_options.descending, nulls_first: sort_options.nulls_first, })]; @@ -2044,16 +2055,17 @@ mod tests { let args = vec![col("b", schema)?]; let logical_args = vec![datafusion_expr::col("b")]; let func = datafusion_expr::AggregateUDF::new_from_impl(LastValue::new()); - create_aggregate_expr( + create_aggregate_expr_with_dfschema( &func, &args, &logical_args, &sort_exprs, &ordering_req, - schema, + dfschema, "LAST_VALUE(b)", false, false, + false, ) } @@ -2086,7 +2098,7 @@ mod tests { Arc::new(TaskContext::default()) }; - let (schema, data) = some_data_v2(); + let (schema, df_schema, data) = some_data_v2(); let partition1 = data[0].clone(); let partition2 = data[1].clone(); let partition3 = data[2].clone(); @@ -2100,9 +2112,13 @@ mod tests { nulls_first: false, }; let aggregates: Vec> = if is_first_acc { - vec![test_first_value_agg_expr(&schema, sort_options)?] + vec![test_first_value_agg_expr( + &schema, + &df_schema, + sort_options, + )?] } else { - vec![test_last_value_agg_expr(&schema, sort_options)?] + vec![test_last_value_agg_expr(&schema, &df_schema, sort_options)?] }; let memory_exec = Arc::new(MemoryExec::try_new( @@ -2169,6 +2185,8 @@ mod tests { #[tokio::test] async fn test_get_finest_requirements() -> Result<()> { let test_schema = create_test_schema()?; + let test_df_schema = DFSchema::try_from(Arc::clone(&test_schema)).unwrap(); + // Assume column a and b are aliases // Assume also that a ASC and c DESC describe the same global ordering for the table. 
(Since they are ordering equivalent). let options1 = SortOptions { @@ -2178,7 +2196,7 @@ mod tests { let col_a = &col("a", &test_schema)?; let col_b = &col("b", &test_schema)?; let col_c = &col("c", &test_schema)?; - let mut eq_properties = EquivalenceProperties::new(test_schema); + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); // Columns a and b are equal. eq_properties.add_equal_conditions(col_a, col_b)?; // Aggregate requirements are @@ -2214,6 +2232,46 @@ mod tests { }, ]), ]; + let col_expr_a = Box::new(datafusion_expr::col("a")); + let col_expr_b = Box::new(datafusion_expr::col("b")); + let col_expr_c = Box::new(datafusion_expr::col("c")); + let sort_exprs = vec![ + None, + Some(vec![datafusion_expr::Expr::Sort(Sort::new( + col_expr_a.clone(), + options1.descending, + options1.nulls_first, + ))]), + Some(vec![ + datafusion_expr::Expr::Sort(Sort::new( + col_expr_a.clone(), + options1.descending, + options1.nulls_first, + )), + datafusion_expr::Expr::Sort(Sort::new( + col_expr_b.clone(), + options1.descending, + options1.nulls_first, + )), + datafusion_expr::Expr::Sort(Sort::new( + col_expr_c, + options1.descending, + options1.nulls_first, + )), + ]), + Some(vec![ + datafusion_expr::Expr::Sort(Sort::new( + col_expr_a, + options1.descending, + options1.nulls_first, + )), + datafusion_expr::Expr::Sort(Sort::new( + col_expr_b, + options1.descending, + options1.nulls_first, + )), + ]), + ]; let common_requirement = vec![ PhysicalSortExpr { expr: Arc::clone(col_a), @@ -2226,14 +2284,23 @@ mod tests { ]; let mut aggr_exprs = order_by_exprs .into_iter() - .map(|order_by_expr| { - Arc::new(OrderSensitiveArrayAgg::new( - Arc::clone(col_a), + .zip(sort_exprs.into_iter()) + .map(|(order_by_expr, sort_exprs)| { + let ordering_req = order_by_expr.unwrap_or_default(); + let sort_exprs = sort_exprs.unwrap_or_default(); + create_aggregate_expr_with_dfschema( + &array_agg_udaf(), + &[Arc::clone(col_a)], + &[], + &sort_exprs, + &ordering_req, + &test_df_schema, "array_agg", - DataType::Int32, - vec![], - order_by_expr.unwrap_or_default(), - )) as _ + false, + false, + false, + ) + .unwrap() }) .collect::>(); let group_by = PhysicalGroupBy::new_single(vec![]); @@ -2254,6 +2321,7 @@ mod tests { Field::new("a", DataType::Float32, true), Field::new("b", DataType::Float32, true), ])); + let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap(); let col_a = col("a", &schema)?; let option_desc = SortOptions { @@ -2263,8 +2331,8 @@ mod tests { let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]); let aggregates: Vec> = vec![ - test_first_value_agg_expr(&schema, option_desc)?, - test_last_value_agg_expr(&schema, option_desc)?, + test_first_value_agg_expr(&schema, &df_schema, option_desc)?, + test_last_value_agg_expr(&schema, &df_schema, option_desc)?, ]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let aggregate_exec = Arc::new(AggregateExec::try_new( @@ -2330,6 +2398,7 @@ mod tests { "1", false, false, + false, )?]; let input_batches = (0..4) diff --git a/datafusion/physical-plan/src/aggregates/no_grouping.rs b/datafusion/physical-plan/src/aggregates/no_grouping.rs index f85164f7f1e2..99417e4ee3e9 100644 --- a/datafusion/physical-plan/src/aggregates/no_grouping.rs +++ b/datafusion/physical-plan/src/aggregates/no_grouping.rs @@ -218,6 +218,7 @@ fn aggregate_batch( Some(filter) => Cow::Owned(batch_filter(&batch, filter)?), None => Cow::Borrowed(&batch), }; + // 1.3 let values = &expr .iter() diff --git 
a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index a1d3378181c2..167ca7240750 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -31,8 +31,9 @@ use crate::common::IPCWriter; use crate::metrics::{BaselineMetrics, RecordOutput}; use crate::sorts::sort::sort_batch; use crate::sorts::streaming_merge; +use crate::spill::read_spill_as_stream; use crate::stream::RecordBatchStreamAdapter; -use crate::{aggregates, read_spill_as_stream, ExecutionPlan, PhysicalExpr}; +use crate::{aggregates, ExecutionPlan, PhysicalExpr}; use crate::{RecordBatchStream, SendableRecordBatchStream}; use arrow::array::*; diff --git a/datafusion/physical-plan/src/analyze.rs b/datafusion/physical-plan/src/analyze.rs index b4c1e25e6191..287446328f8d 100644 --- a/datafusion/physical-plan/src/analyze.rs +++ b/datafusion/physical-plan/src/analyze.rs @@ -206,7 +206,7 @@ impl ExecutionPlan for AnalyzeExec { } } -/// Creates the ouput of AnalyzeExec as a RecordBatch +/// Creates the output of AnalyzeExec as a RecordBatch fn create_output_batch( verbose: bool, show_statistics: bool, diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index bf9d14e73dd8..4b5eea6b760d 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -22,9 +22,9 @@ use std::fs::{metadata, File}; use std::path::{Path, PathBuf}; use std::sync::Arc; -use super::{ExecutionPlanProperties, SendableRecordBatchStream}; +use super::SendableRecordBatchStream; use crate::stream::RecordBatchReceiverStream; -use crate::{ColumnStatistics, ExecutionPlan, Statistics}; +use crate::{ColumnStatistics, Statistics}; use arrow::datatypes::Schema; use arrow::ipc::writer::{FileWriter, IpcWriteOptions}; @@ -33,8 +33,6 @@ use arrow_array::Array; use datafusion_common::stats::Precision; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_execution::memory_pool::MemoryReservation; -use datafusion_physical_expr::expressions::{BinaryExpr, Column}; -use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use futures::{StreamExt, TryStreamExt}; use parking_lot::Mutex; @@ -178,71 +176,6 @@ pub fn compute_record_batch_statistics( } } -/// Calculates the "meet" of given orderings. -/// The meet is the finest ordering that satisfied by all the given -/// orderings, see . -pub fn get_meet_of_orderings( - given: &[Arc], -) -> Option<&[PhysicalSortExpr]> { - given - .iter() - .map(|item| item.output_ordering()) - .collect::>>() - .and_then(get_meet_of_orderings_helper) -} - -fn get_meet_of_orderings_helper( - orderings: Vec<&[PhysicalSortExpr]>, -) -> Option<&[PhysicalSortExpr]> { - let mut idx = 0; - let first = orderings[0]; - loop { - for ordering in orderings.iter() { - if idx >= ordering.len() { - return Some(ordering); - } else { - let schema_aligned = check_expr_alignment( - ordering[idx].expr.as_ref(), - first[idx].expr.as_ref(), - ); - if !schema_aligned || (ordering[idx].options != first[idx].options) { - // In a union, the output schema is that of the first child (by convention). 
- // Therefore, generate the result from the first child's schema: - return if idx > 0 { Some(&first[..idx]) } else { None }; - } - } - } - idx += 1; - } - - fn check_expr_alignment(first: &dyn PhysicalExpr, second: &dyn PhysicalExpr) -> bool { - match ( - first.as_any().downcast_ref::(), - second.as_any().downcast_ref::(), - first.as_any().downcast_ref::(), - second.as_any().downcast_ref::(), - ) { - (Some(first_col), Some(second_col), _, _) => { - first_col.index() == second_col.index() - } - (_, _, Some(first_binary), Some(second_binary)) => { - if first_binary.op() == second_binary.op() { - check_expr_alignment( - first_binary.left().as_ref(), - second_binary.left().as_ref(), - ) && check_expr_alignment( - first_binary.right().as_ref(), - second_binary.right().as_ref(), - ) - } else { - false - } - } - (_, _, _, _) => false, - } - } -} - /// Write in Arrow IPC format. pub struct IPCWriter { /// path @@ -342,297 +275,12 @@ pub fn can_project( #[cfg(test)] mod tests { - use std::ops::Not; - use super::*; - use crate::memory::MemoryExec; - use crate::sorts::sort::SortExec; - use crate::union::UnionExec; - use arrow::compute::SortOptions; use arrow::{ array::{Float32Array, Float64Array, UInt64Array}, datatypes::{DataType, Field}, }; - use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::col; - - #[test] - fn get_meet_of_orderings_helper_common_prefix_test() -> Result<()> { - let input1: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: SortOptions::default(), - }, - ]; - - let input2: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("x", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("y", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("z", 2)), - options: SortOptions::default(), - }, - ]; - - let input3: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("d", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("e", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("f", 2)), - options: SortOptions::default(), - }, - ]; - - let input4: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("g", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("h", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - // Note that index of this column is not 2. Hence this 3rd entry shouldn't be - // in the output ordering. 
- expr: Arc::new(Column::new("i", 3)), - options: SortOptions::default(), - }, - ]; - - let expected = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: SortOptions::default(), - }, - ]; - let result = get_meet_of_orderings_helper(vec![&input1, &input2, &input3]); - assert_eq!(result.unwrap(), expected); - - let expected = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - ]; - let result = get_meet_of_orderings_helper(vec![&input1, &input2, &input4]); - assert_eq!(result.unwrap(), expected); - Ok(()) - } - - #[test] - fn get_meet_of_orderings_helper_subset_test() -> Result<()> { - let input1: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - ]; - - let input2: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("d", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("e", 2)), - options: SortOptions::default(), - }, - ]; - - let input3: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("f", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("g", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("h", 2)), - options: SortOptions::default(), - }, - ]; - - let result = get_meet_of_orderings_helper(vec![&input1, &input2, &input3]); - assert_eq!(result.unwrap(), input1); - Ok(()) - } - - #[test] - fn get_meet_of_orderings_helper_no_overlap_test() -> Result<()> { - let input1: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - // Since ordering is conflicting with other inputs - // output ordering should be empty - options: SortOptions::default().not(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - ]; - - let input2: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("x", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 1)), - options: SortOptions::default(), - }, - ]; - - let input3: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 2)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("y", 1)), - options: SortOptions::default(), - }, - ]; - - let result = get_meet_of_orderings_helper(vec![&input1, &input2]); - assert!(result.is_none()); - - let result = get_meet_of_orderings_helper(vec![&input2, &input3]); - assert!(result.is_none()); - - let result = get_meet_of_orderings_helper(vec![&input1, &input3]); - assert!(result.is_none()); - Ok(()) - } - - #[test] - fn get_meet_of_orderings_helper_binary_exprs() -> Result<()> { - let input1: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("a", 0)), - Operator::Plus, - Arc::new(Column::new("b", 1)), - )), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: 
SortOptions::default(), - }, - ]; - - let input2: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("x", 0)), - Operator::Plus, - Arc::new(Column::new("y", 1)), - )), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("z", 2)), - options: SortOptions::default(), - }, - ]; - - // erroneous input - let input3: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("a", 1)), - Operator::Plus, - Arc::new(Column::new("b", 0)), - )), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: SortOptions::default(), - }, - ]; - - let result = get_meet_of_orderings_helper(vec![&input1, &input2]); - assert_eq!(input1, result.unwrap()); - - let result = get_meet_of_orderings_helper(vec![&input2, &input3]); - assert!(result.is_none()); - - let result = get_meet_of_orderings_helper(vec![&input1, &input3]); - assert!(result.is_none()); - Ok(()) - } - - #[test] - fn test_meet_of_orderings() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("f32", DataType::Float32, false), - Field::new("f64", DataType::Float64, false), - ])); - let sort_expr = vec![PhysicalSortExpr { - expr: col("f32", &schema).unwrap(), - options: SortOptions::default(), - }]; - let memory_exec = - Arc::new(MemoryExec::try_new(&[], Arc::clone(&schema), None)?) as _; - let sort_exec = Arc::new(SortExec::new(sort_expr.clone(), memory_exec)) - as Arc; - let memory_exec2 = Arc::new(MemoryExec::try_new(&[], schema, None)?) as _; - // memory_exec2 doesn't have output ordering - let union_exec = UnionExec::new(vec![Arc::clone(&sort_exec), memory_exec2]); - let res = get_meet_of_orderings(union_exec.inputs()); - assert!(res.is_none()); - - let union_exec = UnionExec::new(vec![Arc::clone(&sort_exec), sort_exec]); - let res = get_meet_of_orderings(union_exec.inputs()); - assert_eq!(res, Some(&sort_expr[..])); - Ok(()) - } #[test] fn test_compute_record_batch_statistics_empty() -> Result<()> { diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index 7f4ae5797d97..0d2653c5c775 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -236,7 +236,7 @@ enum ShowMetrics { /// Do not show any metrics None, - /// Show aggregrated metrics across partition + /// Show aggregated metrics across partition Aggregated, /// Show full per-partition metrics diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index c5ba3992d3b4..a9d78d059f5c 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -28,11 +28,11 @@ use crate::{ metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, DisplayFormatType, ExecutionPlan, }; + use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, SchemaRef}; use arrow::record_batch::RecordBatch; -use arrow_array::{Array, BooleanArray}; -use datafusion_common::cast::{as_boolean_array, as_null_array}; +use datafusion_common::cast::as_boolean_array; use datafusion_common::stats::Precision; use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; use datafusion_execution::TaskContext; @@ -81,19 +81,6 @@ impl FilterExec { cache, }) } - DataType::Null => { - let default_selectivity = 0; - let cache = - Self::compute_properties(&input, &predicate, default_selectivity)?; - - Ok(Self { - predicate, - input: Arc::clone(&input), - metrics: 
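// Aside: with FilterExec's `DataType::Null` constructor arm removed in this hunk,
// the rewritten `batch_filter` in the next hunk just downcasts the predicate result
// to a `BooleanArray` and hands it to Arrow's `filter_record_batch`. A minimal
// sketch of that pairing; the function name is illustrative.
use std::sync::Arc;

use arrow::array::{Array, BooleanArray, Int32Array};
use arrow::compute::filter_record_batch;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::cast::as_boolean_array;
use datafusion_common::Result;

fn filter_sketch() -> Result<RecordBatch> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4]))],
    )?;

    // In FilterExec this array is the result of evaluating the predicate expression.
    let predicate: Arc<dyn Array> =
        Arc::new(BooleanArray::from(vec![true, false, true, false]));

    // Downcast (erroring on non-boolean predicates) and apply the filter kernel.
    let mask = as_boolean_array(&predicate)?;
    Ok(filter_record_batch(&batch, mask)?)
}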
ExecutionPlanMetricsSet::new(), - default_selectivity, - cache, - }) - } other => { plan_err!("Filter predicate must return BOOLEAN values, got {other:?}") } @@ -367,23 +354,15 @@ pub(crate) fn batch_filter( .evaluate(batch) .and_then(|v| v.into_array(batch.num_rows())) .and_then(|array| { - let filter_array = match as_boolean_array(&array) { - Ok(boolean_array) => Ok(boolean_array.to_owned()), + Ok(match as_boolean_array(&array) { + // apply filter array to record batch + Ok(filter_array) => filter_record_batch(batch, filter_array)?, Err(_) => { - let Ok(null_array) = as_null_array(&array) else { - return internal_err!( - "Cannot create filter_array from non-boolean predicates" - ); - }; - - // if the predicate is null, then the result is also null - Ok::(BooleanArray::new_null( - null_array.len(), - )) + return internal_err!( + "Cannot create filter_array from non-boolean predicates" + ); } - }?; - - Ok(filter_record_batch(batch, &filter_array)?) + }) }) } diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 33a9c061bf31..8304ddc7331a 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -578,7 +578,7 @@ mod tests { } #[tokio::test] - async fn test_stats_cartesian_product_with_unknwon_size() { + async fn test_stats_cartesian_product_with_unknown_size() { let left_row_count = 11; let left = Statistics { diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 754e55e49650..f8ca38980850 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -160,7 +160,7 @@ pub struct NestedLoopJoinExec { } impl NestedLoopJoinExec { - /// Try to create a nwe [`NestedLoopJoinExec`] + /// Try to create a new [`NestedLoopJoinExec`] pub fn try_new( left: Arc, right: Arc, diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index e9124a72970a..96d5ba728a30 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -24,40 +24,46 @@ use std::any::Any; use std::cmp::Ordering; use std::collections::VecDeque; use std::fmt::Formatter; +use std::fs::File; +use std::io::BufReader; use std::mem; use std::ops::Range; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::expressions::PhysicalSortExpr; -use crate::joins::utils::{ - build_join_schema, check_join_is_valid, estimate_join_statistics, - symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, -}; -use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; -use crate::{ - execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, - ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, - RecordBatchStream, SendableRecordBatchStream, Statistics, -}; - use arrow::array::*; use arrow::compute::{self, concat_batches, take, SortOptions}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; +use arrow::ipc::reader::FileReader; use arrow_array::types::UInt64Type; +use futures::{Stream, StreamExt}; +use hashbrown::HashSet; use datafusion_common::{ - internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result, + exec_err, internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, + Result, }; 
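// Aside: the spill support added below hinges on `MemoryReservation::try_grow`
// failing once the memory pool is exhausted; rather than surfacing that error, the
// join now falls back to disk. A standalone sketch of that grow-or-spill decision
// against a deliberately tiny pool; the pool size and consumer name are illustrative.
use std::sync::Arc;

use datafusion_execution::memory_pool::{GreedyMemoryPool, MemoryConsumer, MemoryPool};

fn grow_or_spill_sketch() {
    // A pool that only allows 1 KiB in total.
    let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(1024));
    let mut reservation = MemoryConsumer::new("smj_buffered_sketch").register(&pool);

    // A small buffered batch fits in memory.
    assert!(reservation.try_grow(512).is_ok());

    // A batch that does not fit would trigger the spill path instead of failing
    // the query: write the batch to a DiskManager temp file and keep only the
    // file handle, as `allocate_reservation` does below.
    if reservation.try_grow(4096).is_err() {
        // ... spill to disk here ...
    }

    // Releasing an in-memory batch returns its bytes to the pool.
    reservation.shrink(512);
}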
+use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; -use futures::{Stream, StreamExt}; -use hashbrown::HashSet; +use crate::expressions::PhysicalSortExpr; +use crate::joins::utils::{ + build_join_schema, check_join_is_valid, estimate_join_statistics, + symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, +}; +use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; +use crate::spill::spill_record_batches; +use crate::{ + execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, + ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, + RecordBatchStream, SendableRecordBatchStream, Statistics, +}; /// join execution plan executes partitions in parallel and combines them into a set of /// partitions. @@ -234,11 +240,6 @@ impl SortMergeJoinExec { impl DisplayAs for SortMergeJoinExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - let display_filter = self.filter.as_ref().map_or_else( - || "".to_string(), - |f| format!(", filter={}", f.expression()), - ); - match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { let on = self @@ -250,7 +251,12 @@ impl DisplayAs for SortMergeJoinExec { write!( f, "SortMergeJoin: join_type={:?}, on=[{}]{}", - self.join_type, on, display_filter + self.join_type, + on, + self.filter.as_ref().map_or("".to_string(), |f| format!( + ", filter={}", + f.expression() + )) ) } } @@ -375,6 +381,7 @@ impl ExecutionPlan for SortMergeJoinExec { batch_size, SortMergeJoinMetrics::new(partition, &self.metrics), reservation, + context.runtime_env(), )?)) } @@ -412,6 +419,12 @@ struct SortMergeJoinMetrics { /// Peak memory used for buffered data. /// Calculated as sum of peak memory values across partitions peak_mem_used: metrics::Gauge, + /// count of spills during the execution of the operator + spill_count: Count, + /// total spilled bytes during the execution of the operator + spilled_bytes: Count, + /// total spilled rows during the execution of the operator + spilled_rows: Count, } impl SortMergeJoinMetrics { @@ -425,6 +438,9 @@ impl SortMergeJoinMetrics { MetricBuilder::new(metrics).counter("output_batches", partition); let output_rows = MetricBuilder::new(metrics).output_rows(partition); let peak_mem_used = MetricBuilder::new(metrics).gauge("peak_mem_used", partition); + let spill_count = MetricBuilder::new(metrics).spill_count(partition); + let spilled_bytes = MetricBuilder::new(metrics).spilled_bytes(partition); + let spilled_rows = MetricBuilder::new(metrics).spilled_rows(partition); Self { join_time, @@ -433,6 +449,9 @@ impl SortMergeJoinMetrics { output_batches, output_rows, peak_mem_used, + spill_count, + spilled_bytes, + spilled_rows, } } } @@ -565,7 +584,8 @@ impl StreamedBatch { #[derive(Debug)] struct BufferedBatch { /// The buffered record batch - pub batch: RecordBatch, + /// None if the batch spilled to disk th + pub batch: Option, /// The range in which the rows share the same join key pub range: Range, /// Array refs of the join key @@ -577,6 +597,14 @@ struct BufferedBatch { /// The indices of buffered batch that failed the join filter. 
/// When dequeuing the buffered batch, we need to produce null joined rows for these indices. pub join_filter_failed_idxs: HashSet, + /// Current buffered batch number of rows. Equal to batch.num_rows() + /// but if batch is spilled to disk this property is preferable + /// and less expensive + pub num_rows: usize, + /// An optional temp spill file name on the disk if the batch spilled + /// None by default + /// Some(fileName) if the batch spilled to the disk + pub spill_file: Option, } impl BufferedBatch { @@ -602,13 +630,16 @@ impl BufferedBatch { + mem::size_of::>() + mem::size_of::(); + let num_rows = batch.num_rows(); BufferedBatch { - batch, + batch: Some(batch), range, join_arrays, null_joined: vec![], size_estimation, join_filter_failed_idxs: HashSet::new(), + num_rows, + spill_file: None, } } } @@ -634,7 +665,7 @@ struct SMJStream { pub buffered: SendableRecordBatchStream, /// Current processing record batch of streamed pub streamed_batch: StreamedBatch, - /// Currrent buffered data + /// Current buffered data pub buffered_data: BufferedData, /// (used in outer join) Is current streamed row joined at least once? pub streamed_joined: bool, @@ -666,6 +697,8 @@ struct SMJStream { pub join_metrics: SortMergeJoinMetrics, /// Memory reservation pub reservation: MemoryReservation, + /// Runtime env + pub runtime_env: Arc, } impl RecordBatchStream for SMJStream { @@ -785,6 +818,7 @@ impl SMJStream { batch_size: usize, join_metrics: SortMergeJoinMetrics, reservation: MemoryReservation, + runtime_env: Arc, ) -> Result { let streamed_schema = streamed.schema(); let buffered_schema = buffered.schema(); @@ -813,6 +847,7 @@ impl SMJStream { join_type, join_metrics, reservation, + runtime_env, }) } @@ -858,6 +893,58 @@ impl SMJStream { } } + fn free_reservation(&mut self, buffered_batch: BufferedBatch) -> Result<()> { + // Shrink memory usage for in-memory batches only + if buffered_batch.spill_file.is_none() && buffered_batch.batch.is_some() { + self.reservation + .try_shrink(buffered_batch.size_estimation)?; + } + + Ok(()) + } + + fn allocate_reservation(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> { + match self.reservation.try_grow(buffered_batch.size_estimation) { + Ok(_) => { + self.join_metrics + .peak_mem_used + .set_max(self.reservation.size()); + Ok(()) + } + Err(_) if self.runtime_env.disk_manager.tmp_files_enabled() => { + // spill buffered batch to disk + let spill_file = self + .runtime_env + .disk_manager + .create_tmp_file("sort_merge_join_buffered_spill")?; + + if let Some(batch) = buffered_batch.batch { + spill_record_batches( + vec![batch], + spill_file.path().into(), + Arc::clone(&self.buffered_schema), + )?; + buffered_batch.spill_file = Some(spill_file); + buffered_batch.batch = None; + + // update metrics to register spill + self.join_metrics.spill_count.add(1); + self.join_metrics + .spilled_bytes + .add(buffered_batch.size_estimation); + self.join_metrics.spilled_rows.add(buffered_batch.num_rows); + Ok(()) + } else { + internal_err!("Buffered batch has empty body") + } + } + Err(e) => exec_err!("{}. Disk spilling disabled.", e.message()), + }?; + + self.buffered_data.batches.push_back(buffered_batch); + Ok(()) + } + /// Poll next buffered batches fn poll_buffered_batches(&mut self, cx: &mut Context) -> Poll>> { loop { @@ -867,12 +954,12 @@ impl SMJStream { while !self.buffered_data.batches.is_empty() { let head_batch = self.buffered_data.head_batch(); // If the head batch is fully processed, dequeue it and produce output of it. 
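// Aside: `allocate_reservation` above spills a buffered batch to a `DiskManager`
// temp file in Arrow IPC format, and `get_buffered_columns_from_batch` further down
// reads it back with `FileReader`. A self-contained sketch of that round trip using
// the same building blocks; the purpose string and function name are illustrative.
use std::fs::File;
use std::io::BufReader;
use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::ipc::reader::FileReader;
use arrow::ipc::writer::FileWriter;
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_execution::disk_manager::{DiskManager, DiskManagerConfig};

fn spill_round_trip() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;

    // Ask the DiskManager for a ref-counted temp file, as the join does.
    let disk_manager = DiskManager::try_new(DiskManagerConfig::NewOs)?;
    let spill_file = disk_manager.create_tmp_file("smj_spill_sketch")?;

    // Spill: stream the batch out in the Arrow IPC file format.
    let mut writer = FileWriter::try_new(File::create(spill_file.path())?, &schema)?;
    writer.write(&batch)?;
    writer.finish()?;

    // Un-spill: read the batches back when the buffered rows are needed again.
    let reader = FileReader::try_new(BufReader::new(File::open(spill_file.path())?), None)?;
    let mut restored = vec![];
    for maybe_batch in reader {
        restored.push(maybe_batch?);
    }
    assert_eq!(restored[0].num_rows(), 3);
    Ok(())
}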
- if head_batch.range.end == head_batch.batch.num_rows() { + if head_batch.range.end == head_batch.num_rows { self.freeze_dequeuing_buffered()?; if let Some(buffered_batch) = self.buffered_data.batches.pop_front() { - self.reservation.shrink(buffered_batch.size_estimation); + self.free_reservation(buffered_batch)?; } } else { // If the head batch is not fully processed, break the loop. @@ -900,25 +987,22 @@ impl SMJStream { Poll::Ready(Some(batch)) => { self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); + if batch.num_rows() > 0 { let buffered_batch = BufferedBatch::new(batch, 0..1, &self.on_buffered); - self.reservation.try_grow(buffered_batch.size_estimation)?; - self.join_metrics - .peak_mem_used - .set_max(self.reservation.size()); - self.buffered_data.batches.push_back(buffered_batch); + self.allocate_reservation(buffered_batch)?; self.buffered_state = BufferedState::PollingRest; } } }, BufferedState::PollingRest => { if self.buffered_data.tail_batch().range.end - < self.buffered_data.tail_batch().batch.num_rows() + < self.buffered_data.tail_batch().num_rows { while self.buffered_data.tail_batch().range.end - < self.buffered_data.tail_batch().batch.num_rows() + < self.buffered_data.tail_batch().num_rows { if is_join_arrays_equal( &self.buffered_data.head_batch().join_arrays, @@ -941,6 +1025,7 @@ impl SMJStream { self.buffered_state = BufferedState::Ready; } Poll::Ready(Some(batch)) => { + // Polling batches coming concurrently as multiple partitions self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); if batch.num_rows() > 0 { @@ -949,12 +1034,7 @@ impl SMJStream { 0..0, &self.on_buffered, ); - self.reservation - .try_grow(buffered_batch.size_estimation)?; - self.join_metrics - .peak_mem_used - .set_max(self.reservation.size()); - self.buffered_data.batches.push_back(buffered_batch); + self.allocate_reservation(buffered_batch)?; } } } @@ -1473,13 +1553,8 @@ fn produce_buffered_null_batch( } // Take buffered (right) columns - let buffered_columns = buffered_batch - .batch - .columns() - .iter() - .map(|column| take(column, &buffered_indices, None)) - .collect::, ArrowError>>() - .map_err(Into::::into)?; + let buffered_columns = + get_buffered_columns_from_batch(buffered_batch, buffered_indices)?; // Create null streamed (left) columns let mut streamed_columns = streamed_schema @@ -1502,13 +1577,45 @@ fn get_buffered_columns( buffered_data: &BufferedData, buffered_batch_idx: usize, buffered_indices: &UInt64Array, -) -> Result, ArrowError> { - buffered_data.batches[buffered_batch_idx] - .batch - .columns() - .iter() - .map(|column| take(column, &buffered_indices, None)) - .collect::, ArrowError>>() +) -> Result> { + get_buffered_columns_from_batch( + &buffered_data.batches[buffered_batch_idx], + buffered_indices, + ) +} + +#[inline(always)] +fn get_buffered_columns_from_batch( + buffered_batch: &BufferedBatch, + buffered_indices: &UInt64Array, +) -> Result> { + match (&buffered_batch.spill_file, &buffered_batch.batch) { + // In memory batch + (None, Some(batch)) => Ok(batch + .columns() + .iter() + .map(|column| take(column, &buffered_indices, None)) + .collect::, ArrowError>>() + .map_err(Into::::into)?), + // If the batch was spilled to disk, less likely + (Some(spill_file), None) => { + let mut buffered_cols: Vec = + Vec::with_capacity(buffered_indices.len()); + + let file = BufReader::new(File::open(spill_file.path())?); + let reader = FileReader::try_new(file, None)?; + + for batch in reader { + 
batch?.columns().iter().for_each(|column| { + buffered_cols.extend(take(column, &buffered_indices, None)) + }); + } + + Ok(buffered_cols) + } + // Invalid combination + (spill, batch) => internal_err!("Unexpected buffered batch spill status. Spill exists: {}. In-memory exists: {}", spill.is_some(), batch.is_some()), + } } /// Calculate join filter bit mask considering join type specifics @@ -1574,22 +1681,25 @@ fn get_filtered_join_mask( JoinType::LeftAnti => { // have we seen a filter match for a streaming index before for i in 0..streamed_indices_length { - if mask.value(i) && !seen_as_true { + let streamed_idx = streamed_indices.value(i); + if mask.value(i) + && !seen_as_true + && !matched_indices.contains(&streamed_idx) + { seen_as_true = true; - filter_matched_indices.push(streamed_indices.value(i)); + filter_matched_indices.push(streamed_idx); } // Reset `seen_as_true` flag and calculate mask for the current streaming index // - if within the batch it switched to next streaming index(e.g. from 0 to 1, or from 1 to 2) // - if it is at the end of the all buffered batches for the given streaming index, 0 index comes last if (i < streamed_indices_length - 1 - && streamed_indices.value(i) != streamed_indices.value(i + 1)) + && streamed_idx != streamed_indices.value(i + 1)) || (i == streamed_indices_length - 1 && *scanning_buffered_offset == 0) { corrected_mask.append_value( - !matched_indices.contains(&streamed_indices.value(i)) - && !seen_as_true, + !matched_indices.contains(&streamed_idx) && !seen_as_true, ); seen_as_true = false; } else { @@ -1854,6 +1964,7 @@ mod tests { assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, }; use datafusion_execution::config::SessionConfig; + use datafusion_execution::disk_manager::DiskManagerConfig; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_execution::TaskContext; @@ -2749,7 +2860,7 @@ mod tests { } #[tokio::test] - async fn overallocation_single_batch() -> Result<()> { + async fn overallocation_single_batch_no_spill() -> Result<()> { let left = build_table( ("a1", &vec![0, 1, 2, 3, 4, 5]), ("b1", &vec![1, 2, 3, 4, 5, 6]), @@ -2775,14 +2886,17 @@ mod tests { JoinType::LeftAnti, ]; - for join_type in join_types { - let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); - let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); - let session_config = SessionConfig::default().with_batch_size(50); + // Disable DiskManager to prevent spilling + let runtime_config = RuntimeConfig::new() + .with_memory_limit(100, 1.0) + .with_disk_manager(DiskManagerConfig::Disabled); + let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + let session_config = SessionConfig::default().with_batch_size(50); + for join_type in join_types { let task_ctx = TaskContext::default() - .with_session_config(session_config) - .with_runtime(runtime); + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); let task_ctx = Arc::new(task_ctx); let join = join_with_options( @@ -2797,18 +2911,20 @@ mod tests { let stream = join.execute(0, task_ctx)?; let err = common::collect(stream).await.unwrap_err(); - assert_contains!( - err.to_string(), - "Resources exhausted: Failed to allocate additional" - ); + assert_contains!(err.to_string(), "Failed to allocate additional"); assert_contains!(err.to_string(), "SMJStream[0]"); + assert_contains!(err.to_string(), "Disk spilling disabled"); + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), 
Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); } Ok(()) } #[tokio::test] - async fn overallocation_multi_batch() -> Result<()> { + async fn overallocation_multi_batch_no_spill() -> Result<()> { let left_batch_1 = build_table_i32( ("a1", &vec![0, 1]), ("b1", &vec![1, 1]), @@ -2855,13 +2971,17 @@ mod tests { JoinType::LeftAnti, ]; + // Disable DiskManager to prevent spilling + let runtime_config = RuntimeConfig::new() + .with_memory_limit(100, 1.0) + .with_disk_manager(DiskManagerConfig::Disabled); + let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + let session_config = SessionConfig::default().with_batch_size(50); + for join_type in join_types { - let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); - let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); - let session_config = SessionConfig::default().with_batch_size(50); let task_ctx = TaskContext::default() - .with_session_config(session_config) - .with_runtime(runtime); + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); let task_ctx = Arc::new(task_ctx); let join = join_with_options( Arc::clone(&left), @@ -2875,11 +2995,205 @@ mod tests { let stream = join.execute(0, task_ctx)?; let err = common::collect(stream).await.unwrap_err(); - assert_contains!( - err.to_string(), - "Resources exhausted: Failed to allocate additional" - ); + assert_contains!(err.to_string(), "Failed to allocate additional"); assert_contains!(err.to_string(), "SMJStream[0]"); + assert_contains!(err.to_string(), "Disk spilling disabled"); + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); + } + + Ok(()) + } + + #[tokio::test] + async fn overallocation_single_batch_spill() -> Result<()> { + let left = build_table( + ("a1", &vec![0, 1, 2, 3, 4, 5]), + ("b1", &vec![1, 2, 3, 4, 5, 6]), + ("c1", &vec![4, 5, 6, 7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![0, 10, 20, 30, 40]), + ("b2", &vec![1, 3, 4, 6, 8]), + ("c2", &vec![50, 60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) 
as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let join_types = [ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftSemi, + JoinType::LeftAnti, + ]; + + // Enable DiskManager to allow spilling + let runtime_config = RuntimeConfig::new() + .with_memory_limit(100, 1.0) + .with_disk_manager(DiskManagerConfig::NewOs); + let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + + for batch_size in [1, 50] { + let session_config = SessionConfig::default().with_batch_size(batch_size); + + for join_type in &join_types { + let task_ctx = TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); + let task_ctx = Arc::new(task_ctx); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + false, + )?; + + let stream = join.execute(0, task_ctx)?; + let spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert!(join.metrics().unwrap().spill_count().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); + + // Run the test with no spill configuration as + let task_ctx_no_spill = + TaskContext::default().with_session_config(session_config.clone()); + let task_ctx_no_spill = Arc::new(task_ctx_no_spill); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + false, + )?; + let stream = join.execute(0, task_ctx_no_spill)?; + let no_spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); + // Compare spilled and non spilled data to check spill logic doesn't corrupt the data + assert_eq!(spilled_join_result, no_spilled_join_result); + } + } + + Ok(()) + } + + #[tokio::test] + async fn overallocation_multi_batch_spill() -> Result<()> { + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![2, 3]), + ("b1", &vec![1, 1]), + ("c1", &vec![6, 7]), + ); + let left_batch_3 = build_table_i32( + ("a1", &vec![4, 5]), + ("b1", &vec![1, 1]), + ("c1", &vec![8, 9]), + ); + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 10]), + ("b2", &vec![1, 1]), + ("c2", &vec![50, 60]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![20, 30]), + ("b2", &vec![1, 1]), + ("c2", &vec![70, 80]), + ); + let right_batch_3 = + build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90])); + let left = + build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]); + let right = + build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) 
as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let join_types = [ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftSemi, + JoinType::LeftAnti, + ]; + + // Enable DiskManager to allow spilling + let runtime_config = RuntimeConfig::new() + .with_memory_limit(500, 1.0) + .with_disk_manager(DiskManagerConfig::NewOs); + let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + for batch_size in [1, 50] { + let session_config = SessionConfig::default().with_batch_size(batch_size); + + for join_type in &join_types { + let task_ctx = TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); + let task_ctx = Arc::new(task_ctx); + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + false, + )?; + + let stream = join.execute(0, task_ctx)?; + let spilled_join_result = common::collect(stream).await.unwrap(); + assert!(join.metrics().is_some()); + assert!(join.metrics().unwrap().spill_count().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); + + // Run the test with no spill configuration as + let task_ctx_no_spill = + TaskContext::default().with_session_config(session_config.clone()); + let task_ctx_no_spill = Arc::new(task_ctx_no_spill); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + false, + )?; + let stream = join.execute(0, task_ctx_no_spill)?; + let no_spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); + // Compare spilled and non spilled data to check spill logic doesn't corrupt the data + assert_eq!(spilled_join_result, no_spilled_join_result); + } } Ok(()) diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index c23dc2032c4b..2299b7ff07f1 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -215,7 +215,7 @@ impl SymmetricHashJoinExec { let left_schema = left.schema(); let right_schema = right.schema(); - // Error out if no "on" contraints are given: + // Error out if no "on" constraints are given: if on.is_empty() { return plan_err!( "On constraints in SymmetricHashJoinExec should be non-empty" diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index e3ec242ce8de..51744730a5a1 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -145,7 +145,7 @@ impl JoinHashMap { pub(crate) type JoinHashMapOffset = (usize, Option); // Macro for traversing chained values with limit. -// Early returns in case of reacing output tuples limit. +// Early returns in case of reaching output tuples limit. macro_rules! 
chain_traverse { ( $input_indices:ident, $match_indices:ident, $hash_values:ident, $next_chain:ident, @@ -477,7 +477,7 @@ fn offset_ordering( offset: usize, ) -> Vec { match join_type { - // In the case below, right ordering should be offseted with the left + // In the case below, right ordering should be offsetted with the left // side length, since we append the right table to the left table. JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => ordering .iter() @@ -910,7 +910,7 @@ fn estimate_inner_join_cardinality( left_stats: Statistics, right_stats: Statistics, ) -> Option> { - // Immediatedly return if inputs considered as non-overlapping + // Immediately return if inputs considered as non-overlapping if let Some(estimation) = estimate_disjoint_inputs(&left_stats, &right_stats) { return Some(estimation); }; @@ -2419,7 +2419,7 @@ mod tests { ); assert!( absent_outer_estimation.is_none(), - "Expected \"None\" esimated SemiJoin cardinality for absent outer num_rows" + "Expected \"None\" estimated SemiJoin cardinality for absent outer num_rows" ); let absent_inner_estimation = estimate_join_cardinality( @@ -2437,7 +2437,7 @@ mod tests { &join_on, ).expect("Expected non-empty PartialJoinStatistics for SemiJoin with absent inner num_rows"); - assert_eq!(absent_inner_estimation.num_rows, 500, "Expected outer.num_rows esimated SemiJoin cardinality for absent inner num_rows"); + assert_eq!(absent_inner_estimation.num_rows, 500, "Expected outer.num_rows estimated SemiJoin cardinality for absent inner num_rows"); let absent_inner_estimation = estimate_join_cardinality( &JoinType::LeftSemi, @@ -2453,7 +2453,7 @@ mod tests { }, &join_on, ); - assert!(absent_inner_estimation.is_none(), "Expected \"None\" esimated SemiJoin cardinality for absent outer and inner num_rows"); + assert!(absent_inner_estimation.is_none(), "Expected \"None\" estimated SemiJoin cardinality for absent outer and inner num_rows"); Ok(()) } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index dc736993a453..c834005bb7c3 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -21,31 +21,41 @@ use std::any::Any; use std::fmt::Debug; -use std::fs::File; -use std::io::BufReader; -use std::path::{Path, PathBuf}; use std::sync::Arc; -use crate::coalesce_partitions::CoalescePartitionsExec; -use crate::display::DisplayableExecutionPlan; -use crate::metrics::MetricsSet; -use crate::repartition::RepartitionExec; -use crate::sorts::sort_preserving_merge::SortPreservingMergeExec; - use arrow::datatypes::SchemaRef; -use arrow::ipc::reader::FileReader; use arrow::record_batch::RecordBatch; +use futures::stream::{StreamExt, TryStreamExt}; +use tokio::task::JoinSet; + use datafusion_common::config::ConfigOptions; -use datafusion_common::{exec_datafusion_err, exec_err, Result}; +pub use datafusion_common::hash_utils; +pub use datafusion_common::utils::project_schema; +use datafusion_common::{exec_err, Result}; +pub use datafusion_common::{internal_err, ColumnStatistics, Statistics}; use datafusion_execution::TaskContext; +pub use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; +pub use datafusion_expr::{Accumulator, ColumnarValue}; +pub use datafusion_physical_expr::window::WindowExpr; +pub use datafusion_physical_expr::{ + expressions, functions, udf, AggregateExpr, Distribution, Partitioning, PhysicalExpr, +}; use datafusion_physical_expr::{ EquivalenceProperties, LexOrdering, PhysicalSortExpr, PhysicalSortRequirement, }; 
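Note: the allocate_reservation/free_reservation methods added to SMJStream earlier in this diff follow a grow-or-spill policy: try to reserve memory for a buffered batch, fall back to writing it to a temp file when the reservation fails and a disk manager is available, and only surface an error (mentioning that spilling is disabled) otherwise. A simplified, self-contained sketch of that decision, with a hypothetical Pool type standing in for MemoryReservation:

struct Pool { used: usize, limit: usize }

impl Pool {
    fn try_grow(&mut self, bytes: usize) -> Result<(), String> {
        if self.used + bytes > self.limit {
            Err(format!("cannot reserve {bytes} bytes"))
        } else {
            self.used += bytes;
            Ok(())
        }
    }
}

enum Placement { InMemory, Spilled }

fn buffer_batch(pool: &mut Pool, size: usize, spilling_enabled: bool) -> Result<Placement, String> {
    match pool.try_grow(size) {
        Ok(()) => Ok(Placement::InMemory),
        // Out of memory but a disk manager is available: the real code writes the
        // batch to a temp IPC file and bumps spill_count/spilled_bytes/spilled_rows.
        Err(_) if spilling_enabled => Ok(Placement::Spilled),
        // Out of memory and spilling disabled: surface the error, as the
        // "Disk spilling disabled" assertions in the *_no_spill tests expect.
        Err(e) => Err(format!("{e}. Disk spilling disabled.")),
    }
}

fn main() {
    let mut pool = Pool { used: 0, limit: 100 };
    assert!(matches!(buffer_batch(&mut pool, 60, true), Ok(Placement::InMemory)));
    assert!(matches!(buffer_batch(&mut pool, 60, true), Ok(Placement::Spilled)));
    assert!(buffer_batch(&mut pool, 60, false).is_err());
}

Keeping the error path separate from the spill path is what lets the no-spill tests above assert on both the error message and zero spill metrics, while the spill tests compare spilled and non-spilled results for equality.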
-use futures::stream::{StreamExt, TryStreamExt}; -use log::debug; -use tokio::sync::mpsc::Sender; -use tokio::task::JoinSet; +use crate::coalesce_partitions::CoalescePartitionsExec; +use crate::display::DisplayableExecutionPlan; +pub use crate::display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay}; +pub use crate::metrics::Metric; +use crate::metrics::MetricsSet; +pub use crate::ordering::InputOrderMode; +use crate::repartition::RepartitionExec; +use crate::sorts::sort_preserving_merge::SortPreservingMergeExec; +pub use crate::stream::EmptyRecordBatchStream; +use crate::stream::RecordBatchStreamAdapter; +pub use crate::topk::TopK; +pub use crate::visitor::{accept, visit_execution_plan, ExecutionPlanVisitor}; mod ordering; mod topk; @@ -70,6 +80,7 @@ pub mod projection; pub mod recursive_query; pub mod repartition; pub mod sorts; +pub mod spill; pub mod stream; pub mod streaming; pub mod tree_node; @@ -79,32 +90,9 @@ pub mod values; pub mod windows; pub mod work_table; -pub use crate::display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay}; -pub use crate::metrics::Metric; -pub use crate::ordering::InputOrderMode; -pub use crate::topk::TopK; -pub use crate::visitor::{accept, visit_execution_plan, ExecutionPlanVisitor}; - -pub use datafusion_common::hash_utils; -pub use datafusion_common::utils::project_schema; -pub use datafusion_common::{internal_err, ColumnStatistics, Statistics}; -pub use datafusion_expr::{Accumulator, ColumnarValue}; -pub use datafusion_physical_expr::window::WindowExpr; -pub use datafusion_physical_expr::{ - expressions, functions, udf, AggregateExpr, Distribution, Partitioning, PhysicalExpr, -}; - -// Backwards compatibility -use crate::common::IPCWriter; -pub use crate::stream::EmptyRecordBatchStream; -use crate::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; -use datafusion_execution::disk_manager::RefCountedTempFile; -use datafusion_execution::memory_pool::human_readable_size; -pub use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; - pub mod udaf { pub use datafusion_physical_expr_common::aggregate::{ - create_aggregate_expr, AggregateFunctionExpr, + create_aggregate_expr, create_aggregate_expr_with_dfschema, AggregateFunctionExpr, }; } @@ -903,63 +891,13 @@ pub fn get_plan_string(plan: &Arc) -> Vec { actual.iter().map(|elem| elem.to_string()).collect() } -/// Read spilled batches from the disk -/// -/// `path` - temp file -/// `schema` - batches schema, should be the same across batches -/// `buffer` - internal buffer of capacity batches -pub fn read_spill_as_stream( - path: RefCountedTempFile, - schema: SchemaRef, - buffer: usize, -) -> Result { - let mut builder = RecordBatchReceiverStream::builder(schema, buffer); - let sender = builder.tx(); - - builder.spawn_blocking(move || read_spill(sender, path.path())); - - Ok(builder.build()) -} - -/// Spills in-memory `batches` to disk. -/// -/// Returns total number of the rows spilled to disk. 
-pub fn spill_record_batches( - batches: Vec, - path: PathBuf, - schema: SchemaRef, -) -> Result { - let mut writer = IPCWriter::new(path.as_ref(), schema.as_ref())?; - for batch in batches { - writer.write(&batch)?; - } - writer.finish()?; - debug!( - "Spilled {} batches of total {} rows to disk, memory released {}", - writer.num_batches, - writer.num_rows, - human_readable_size(writer.num_bytes), - ); - Ok(writer.num_rows) -} - -fn read_spill(sender: Sender>, path: &Path) -> Result<()> { - let file = BufReader::new(File::open(path)?); - let reader = FileReader::try_new(file, None)?; - for batch in reader { - sender - .blocking_send(batch.map_err(Into::into)) - .map_err(|e| exec_datafusion_err!("{e}"))?; - } - Ok(()) -} - #[cfg(test)] mod tests { use std::any::Any; use std::sync::Arc; use arrow_schema::{Schema, SchemaRef}; + use datafusion_common::{Result, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index 9c77a3d05cc2..f3dad6afabde 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -393,7 +393,7 @@ impl ExecutionPlan for LocalLimitExec { .. } if nr <= self.fetch => input_stats, // if the input is greater than the limit, the num_row will be greater - // than the limit because the partitions will be limited separatly + // than the limit because the partitions will be limited separately // the statistic Statistics { num_rows: Precision::Exact(nr), diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index e5c506403ff6..4870e9e95eb5 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -1345,8 +1345,8 @@ mod tests { #[tokio::test] // As the hash results might be different on different platforms or - // wiht different compilers, we will compare the same execution with - // and without droping the output stream. + // with different compilers, we will compare the same execution with + // and without dropping the output stream. async fn hash_repartition_with_dropping_output_stream() { let task_ctx = Arc::new(TaskContext::default()); let partitioning = Partitioning::Hash( @@ -1357,7 +1357,7 @@ mod tests { 2, ); - // We first collect the results without droping the output stream. + // We first collect the results without dropping the output stream. 
let input = Arc::new(make_barrier_exec()); let exec = RepartitionExec::try_new( Arc::clone(&input) as Arc, diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index f347a0f5b6d5..13ff63c17405 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -30,13 +30,13 @@ use crate::metrics::{ BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, }; use crate::sorts::streaming_merge::streaming_merge; +use crate::spill::{read_spill_as_stream, spill_record_batches}; use crate::stream::RecordBatchStreamAdapter; use crate::topk::TopK; use crate::{ - read_spill_as_stream, spill_record_batches, DisplayAs, DisplayFormatType, - Distribution, EmptyRecordBatchStream, ExecutionMode, ExecutionPlan, - ExecutionPlanProperties, Partitioning, PlanProperties, SendableRecordBatchStream, - Statistics, + DisplayAs, DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionMode, + ExecutionPlan, ExecutionPlanProperties, Partitioning, PlanProperties, + SendableRecordBatchStream, Statistics, }; use arrow::compute::{concat_batches, lexsort_to_indices, take, SortColumn}; @@ -45,7 +45,7 @@ use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, SortField}; use arrow_array::{Array, RecordBatchOptions, UInt32Array}; use arrow_schema::DataType; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{internal_err, Result}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::runtime_env::RuntimeEnv; @@ -333,10 +333,7 @@ impl ExternalSorter { for spill in self.spills.drain(..) { if !spill.path().exists() { - return Err(DataFusionError::Internal(format!( - "Spill file {:?} does not exist", - spill.path() - ))); + return internal_err!("Spill file {:?} does not exist", spill.path()); } let stream = read_spill_as_stream(spill, Arc::clone(&self.schema), 2)?; streams.push(stream); @@ -602,7 +599,7 @@ pub fn sort_batch( .collect::>>()?; let indices = if is_multi_column_with_lists(&sort_columns) { - // lex_sort_to_indices doesn't support List with more than one colum + // lex_sort_to_indices doesn't support List with more than one column // https://github.com/apache/arrow-rs/issues/5454 lexsort_to_indices_multi_columns(sort_columns, fetch)? } else { @@ -802,12 +799,12 @@ impl DisplayAs for SortExec { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { let expr = PhysicalSortExpr::format_list(&self.expr); - let preserve_partioning = self.preserve_partitioning; + let preserve_partitioning = self.preserve_partitioning; match self.fetch { Some(fetch) => { - write!(f, "SortExec: TopK(fetch={fetch}), expr=[{expr}], preserve_partitioning=[{preserve_partioning}]",) + write!(f, "SortExec: TopK(fetch={fetch}), expr=[{expr}], preserve_partitioning=[{preserve_partitioning}]",) } - None => write!(f, "SortExec: expr=[{expr}], preserve_partitioning=[{preserve_partioning}]"), + None => write!(f, "SortExec: expr=[{expr}], preserve_partitioning=[{preserve_partitioning}]"), } } } diff --git a/datafusion/physical-plan/src/spill.rs b/datafusion/physical-plan/src/spill.rs new file mode 100644 index 000000000000..21ca58fa0a9f --- /dev/null +++ b/datafusion/physical-plan/src/spill.rs @@ -0,0 +1,186 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the spilling functions + +use std::fs::File; +use std::io::BufReader; +use std::path::{Path, PathBuf}; + +use arrow::datatypes::SchemaRef; +use arrow::ipc::reader::FileReader; +use arrow::record_batch::RecordBatch; +use log::debug; +use tokio::sync::mpsc::Sender; + +use datafusion_common::{exec_datafusion_err, Result}; +use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::memory_pool::human_readable_size; +use datafusion_execution::SendableRecordBatchStream; + +use crate::common::IPCWriter; +use crate::stream::RecordBatchReceiverStream; + +/// Read spilled batches from the disk +/// +/// `path` - temp file +/// `schema` - batches schema, should be the same across batches +/// `buffer` - internal buffer of capacity batches +pub(crate) fn read_spill_as_stream( + path: RefCountedTempFile, + schema: SchemaRef, + buffer: usize, +) -> Result { + let mut builder = RecordBatchReceiverStream::builder(schema, buffer); + let sender = builder.tx(); + + builder.spawn_blocking(move || read_spill(sender, path.path())); + + Ok(builder.build()) +} + +/// Spills in-memory `batches` to disk. +/// +/// Returns total number of the rows spilled to disk. 
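Note: read_spill_as_stream above hands the blocking IPC read to a background task and forwards batches over a channel to the consuming stream. A rough, std-only sketch of that shape, with a hypothetical Batch type and a plain thread plus sync_channel standing in for RecordBatchReceiverStream and spawn_blocking:

use std::sync::mpsc;
use std::thread;

/// Illustrative stand-in for a RecordBatch read back from a spill file.
struct Batch(Vec<i32>);

/// Background task reads batches and forwards them over a bounded channel,
/// mirroring the builder.spawn_blocking(..) plus sender pattern above.
fn read_spill_as_stream_sketch(batches_on_disk: Vec<Batch>) -> mpsc::Receiver<Batch> {
    let (tx, rx) = mpsc::sync_channel(2); // bounded, like the `buffer` argument
    thread::spawn(move || {
        for b in batches_on_disk {
            if tx.send(b).is_err() {
                break; // receiver hung up, stop reading
            }
        }
    });
    rx
}

fn main() {
    let rx = read_spill_as_stream_sketch(vec![Batch(vec![1, 2]), Batch(vec![3])]);
    let total: usize = rx.iter().map(|b| b.0.len()).sum();
    assert_eq!(total, 3);
}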
+pub(crate) fn spill_record_batches( + batches: Vec, + path: PathBuf, + schema: SchemaRef, +) -> Result { + let mut writer = IPCWriter::new(path.as_ref(), schema.as_ref())?; + for batch in batches { + writer.write(&batch)?; + } + writer.finish()?; + debug!( + "Spilled {} batches of total {} rows to disk, memory released {}", + writer.num_batches, + writer.num_rows, + human_readable_size(writer.num_bytes), + ); + Ok(writer.num_rows) +} + +fn read_spill(sender: Sender>, path: &Path) -> Result<()> { + let file = BufReader::new(File::open(path)?); + let reader = FileReader::try_new(file, None)?; + for batch in reader { + sender + .blocking_send(batch.map_err(Into::into)) + .map_err(|e| exec_datafusion_err!("{e}"))?; + } + Ok(()) +} + +/// Spill the `RecordBatch` to disk as smaller batches +/// split by `batch_size_rows` +/// Return `total_rows` what is spilled +pub fn spill_record_batch_by_size( + batch: &RecordBatch, + path: PathBuf, + schema: SchemaRef, + batch_size_rows: usize, +) -> Result<()> { + let mut offset = 0; + let total_rows = batch.num_rows(); + let mut writer = IPCWriter::new(&path, schema.as_ref())?; + + while offset < total_rows { + let length = std::cmp::min(total_rows - offset, batch_size_rows); + let batch = batch.slice(offset, length); + offset += batch.num_rows(); + writer.write(&batch)?; + } + writer.finish()?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use crate::spill::{spill_record_batch_by_size, spill_record_batches}; + use crate::test::build_table_i32; + use datafusion_common::Result; + use datafusion_execution::disk_manager::DiskManagerConfig; + use datafusion_execution::DiskManager; + use std::fs::File; + use std::io::BufReader; + use std::sync::Arc; + + #[test] + fn test_batch_spill_and_read() -> Result<()> { + let batch1 = build_table_i32( + ("a2", &vec![0, 1, 2]), + ("b2", &vec![3, 4, 5]), + ("c2", &vec![4, 5, 6]), + ); + + let batch2 = build_table_i32( + ("a2", &vec![10, 11, 12]), + ("b2", &vec![13, 14, 15]), + ("c2", &vec![14, 15, 16]), + ); + + let disk_manager = DiskManager::try_new(DiskManagerConfig::NewOs)?; + + let spill_file = disk_manager.create_tmp_file("Test Spill")?; + let schema = batch1.schema(); + let num_rows = batch1.num_rows() + batch2.num_rows(); + let cnt = spill_record_batches( + vec![batch1, batch2], + spill_file.path().into(), + Arc::clone(&schema), + ); + assert_eq!(cnt.unwrap(), num_rows); + + let file = BufReader::new(File::open(spill_file.path())?); + let reader = arrow::ipc::reader::FileReader::try_new(file, None)?; + + assert_eq!(reader.num_batches(), 2); + assert_eq!(reader.schema(), schema); + + Ok(()) + } + + #[test] + fn test_batch_spill_by_size() -> Result<()> { + let batch1 = build_table_i32( + ("a2", &vec![0, 1, 2, 3]), + ("b2", &vec![3, 4, 5, 6]), + ("c2", &vec![4, 5, 6, 7]), + ); + + let disk_manager = DiskManager::try_new(DiskManagerConfig::NewOs)?; + + let spill_file = disk_manager.create_tmp_file("Test Spill")?; + let schema = batch1.schema(); + spill_record_batch_by_size( + &batch1, + spill_file.path().into(), + Arc::clone(&schema), + 1, + )?; + + let file = BufReader::new(File::open(spill_file.path())?); + let reader = arrow::ipc::reader::FileReader::try_new(file, None)?; + + assert_eq!(reader.num_batches(), 4); + assert_eq!(reader.schema(), schema); + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs index 5a9035c8dbfc..e10e5c9a6995 100644 --- a/datafusion/physical-plan/src/streaming.rs +++ b/datafusion/physical-plan/src/streaming.rs @@ -80,7 +80,7 @@ 
impl StreamingTableExec { if !schema.eq(partition_schema) { debug!( "Target schema does not match with partition schema. \ - Target_schema: {schema:?}. Partiton Schema: {partition_schema:?}" + Target_schema: {schema:?}. Partition Schema: {partition_schema:?}" ); return plan_err!("Mismatch between schema and batches"); } diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs index ac4eb1ca9e58..cf1c0e313733 100644 --- a/datafusion/physical-plan/src/test/exec.rs +++ b/datafusion/physical-plan/src/test/exec.rs @@ -725,7 +725,7 @@ pub struct PanicExec { schema: SchemaRef, /// Number of output partitions. Each partition will produce this - /// many empty output record batches prior to panicing + /// many empty output record batches prior to panicking batches_until_panics: Vec, cache: PlanProperties, } diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 5366a5707696..d3f1a4fd96ca 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -94,7 +94,7 @@ pub struct TopK { impl TopK { /// Create a new [`TopK`] that stores the top `k` values, as /// defined by the sort expressions in `expr`. - // TOOD: make a builder or some other nicer API to avoid the + // TODO: make a builder or some other nicer API to avoid the // clippy warning #[allow(clippy::too_many_arguments)] pub fn try_new( @@ -258,7 +258,7 @@ impl TopKMetrics { /// Using the `Row` format handles things such as ascending vs /// descending and nulls first vs nulls last. struct TopKHeap { - /// The maximum number of elemenents to store in this heap. + /// The maximum number of elements to store in this heap. k: usize, /// The target number of rows for output batches batch_size: usize, @@ -421,7 +421,7 @@ impl TopKHeap { let num_rows = self.inner.len(); let (new_batch, mut topk_rows) = self.emit_with_state()?; - // clear all old entires in store (this invalidates all + // clear all old entries in store (this invalidates all // store_ids in `inner`) self.store.clear(); @@ -453,7 +453,7 @@ impl TopKHeap { /// Represents one of the top K rows held in this heap. Orders /// according to memcmp of row (e.g. the arrow Row format, but could -/// also be primtive values) +/// also be primitive values) /// /// Reuses allocations to minimize runtime overhead of creating new Vecs #[derive(Debug, PartialEq)] diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index b39c6aee82b9..9321fdb2cadf 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -41,7 +41,7 @@ use arrow::record_batch::RecordBatch; use datafusion_common::stats::Precision; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{ConstExpr, EquivalenceProperties}; +use datafusion_physical_expr::{calculate_union, EquivalenceProperties}; use futures::Stream; use itertools::Itertools; @@ -99,7 +99,12 @@ impl UnionExec { /// Create a new UnionExec pub fn new(inputs: Vec>) -> Self { let schema = union_schema(&inputs); - let cache = Self::compute_properties(&inputs, schema); + // The schema of the inputs and the union schema is consistent when: + // - They have the same number of fields, and + // - Their fields have same types at the same indices. + // Here, we know that schemas are consistent and the call below can + // not return an error. 
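Note: spill_record_batch_by_size above walks the input batch in batch_size_rows steps, slicing and writing one chunk per iteration. The offset arithmetic, pulled out into a standalone sketch (chunk_offsets is a hypothetical helper, not part of the patch):

fn chunk_offsets(total_rows: usize, batch_size_rows: usize) -> Vec<(usize, usize)> {
    // (offset, length) pairs, mirroring the while-loop in spill_record_batch_by_size.
    let mut out = Vec::new();
    let mut offset = 0;
    while offset < total_rows {
        let length = std::cmp::min(total_rows - offset, batch_size_rows);
        out.push((offset, length));
        offset += length;
    }
    out
}

fn main() {
    // A 4-row batch spilled with batch_size_rows = 1 yields four 1-row slices,
    // matching the expectation in `test_batch_spill_by_size` above.
    let expected: Vec<(usize, usize)> = vec![(0, 1), (1, 1), (2, 1), (3, 1)];
    assert_eq!(chunk_offsets(4, 1), expected);
}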
+ let cache = Self::compute_properties(&inputs, schema).unwrap(); UnionExec { inputs, metrics: ExecutionPlanMetricsSet::new(), @@ -116,13 +121,13 @@ impl UnionExec { fn compute_properties( inputs: &[Arc], schema: SchemaRef, - ) -> PlanProperties { + ) -> Result { // Calculate equivalence properties: - let children_eqs = inputs + let children_eqps = inputs .iter() - .map(|child| child.equivalence_properties()) + .map(|child| child.equivalence_properties().clone()) .collect::>(); - let eq_properties = calculate_union_eq_properties(&children_eqs, schema); + let eq_properties = calculate_union(children_eqps, schema)?; // Calculate output partitioning; i.e. sum output partitions of the inputs. let num_partitions = inputs @@ -134,70 +139,12 @@ impl UnionExec { // Determine execution mode: let mode = execution_mode_from_children(inputs.iter()); - PlanProperties::new(eq_properties, output_partitioning, mode) - } -} -/// Calculate `EquivalenceProperties` for `UnionExec` from the `EquivalenceProperties` -/// of its children. -fn calculate_union_eq_properties( - children_eqs: &[&EquivalenceProperties], - schema: SchemaRef, -) -> EquivalenceProperties { - // Calculate equivalence properties: - // TODO: In some cases, we should be able to preserve some equivalence - // classes and constants. Add support for such cases. - let mut eq_properties = EquivalenceProperties::new(schema); - // Use the ordering equivalence class of the first child as the seed: - let mut meets = children_eqs[0] - .oeq_class() - .iter() - .map(|item| item.to_vec()) - .collect::>(); - // Iterate over all the children: - for child_eqs in &children_eqs[1..] { - // Compute meet orderings of the current meets and the new ordering - // equivalence class. - let mut idx = 0; - while idx < meets.len() { - // Find all the meets of `current_meet` with this child's orderings: - let valid_meets = child_eqs.oeq_class().iter().filter_map(|ordering| { - child_eqs.get_meet_ordering(ordering, &meets[idx]) - }); - // Use the longest of these meets as others are redundant: - if let Some(next_meet) = valid_meets.max_by_key(|m| m.len()) { - meets[idx] = next_meet; - idx += 1; - } else { - meets.swap_remove(idx); - } - } - } - // We know have all the valid orderings after union, remove redundant - // entries (implicitly) and return: - eq_properties.add_new_orderings(meets); - - let mut meet_constants = children_eqs[0].constants().to_vec(); - // Iterate over all the children: - for child_eqs in &children_eqs[1..] { - let constants = child_eqs.constants(); - meet_constants = meet_constants - .into_iter() - .filter_map(|meet_constant| { - for const_expr in constants { - if const_expr.expr().eq(meet_constant.expr()) { - // TODO: Check whether constant expressions evaluates the same value or not for each partition - let across_partitions = false; - return Some( - ConstExpr::from(meet_constant.owned_expr()) - .with_across_partitions(across_partitions), - ); - } - } - None - }) - .collect::>(); + Ok(PlanProperties::new( + eq_properties, + output_partitioning, + mode, + )) } - eq_properties.add_constants(meet_constants) } impl DisplayAs for UnionExec { @@ -431,6 +378,12 @@ impl ExecutionPlan for InterleaveExec { self: Arc, children: Vec>, ) -> Result> { + // New children are no longer interleavable, which might be a bug of optimization rewrite. 
+ if !can_interleave(children.iter()) { + return internal_err!( + "Can not create InterleaveExec: new children can not be interleaved" + ); + } Ok(Arc::new(InterleaveExec::try_new(children)?)) } @@ -633,8 +586,8 @@ mod tests { use arrow_schema::{DataType, SortOptions}; use datafusion_common::ScalarValue; - use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; + use datafusion_physical_expr_common::expressions::column::col; // Generate a schema which consists of 7 columns (a, b, c, d, e, f, g) fn create_test_schema() -> Result { @@ -850,23 +803,31 @@ mod tests { .with_sort_information(second_orderings), ); + let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema)); + union_expected_eq.add_new_orderings(union_expected_orderings); + let union = UnionExec::new(vec![child1, child2]); let union_eq_properties = union.properties().equivalence_properties(); - let union_actual_orderings = union_eq_properties.oeq_class(); let err_msg = format!( "Error in test id: {:?}, test case: {:?}", test_idx, test_cases[test_idx] ); - assert_eq!( - union_actual_orderings.len(), - union_expected_orderings.len(), - "{}", - err_msg - ); - for expected in &union_expected_orderings { - assert!(union_actual_orderings.contains(expected), "{}", err_msg); - } + assert_eq_properties_same(union_eq_properties, &union_expected_eq, err_msg); } Ok(()) } + + fn assert_eq_properties_same( + lhs: &EquivalenceProperties, + rhs: &EquivalenceProperties, + err_msg: String, + ) { + // Check whether orderings are same. + let lhs_orderings = lhs.oeq_class(); + let rhs_orderings = &rhs.oeq_class.orderings; + assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{}", err_msg); + for rhs_ordering in rhs_orderings { + assert!(lhs_orderings.contains(rhs_ordering), "{}", err_msg); + } + } } diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 7f794556a241..959796489c19 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -157,6 +157,7 @@ pub fn create_window_expr( name, ignore_nulls, false, + false, )?; window_expr_from_aggregate_expr( partition_by, @@ -805,7 +806,7 @@ mod tests { } #[tokio::test] - async fn test_satisfiy_nullable() -> Result<()> { + async fn test_satisfy_nullable() -> Result<()> { let schema = create_test_schema()?; let params = vec![ ((true, true), (false, false), false), diff --git a/datafusion/physical-plan/src/windows/window_agg_exec.rs b/datafusion/physical-plan/src/windows/window_agg_exec.rs index b6330f65e0b7..1d5c6061a0f9 100644 --- a/datafusion/physical-plan/src/windows/window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/window_agg_exec.rs @@ -126,7 +126,7 @@ impl WindowAggExec { // Get output partitioning: // Because we can have repartitioning using the partition keys this - // would be either 1 or more than 1 depending on the presense of repartitioning. + // would be either 1 or more than 1 depending on the presence of repartitioning. 
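Note: UnionExec::compute_properties now delegates ordering and constant merging to calculate_union and returns a Result, replacing the hand-rolled meet computation removed above. The underlying idea is that a union can only promise orderings every input provides; a deliberately simplified sketch of that intersection (the real calculate_union also handles constants and common ordering prefixes):

fn union_orderings<'a>(children: &[Vec<Vec<&'a str>>]) -> Vec<Vec<&'a str>> {
    // Keep only the orderings that every child can guarantee.
    let Some((first, rest)) = children.split_first() else {
        return vec![];
    };
    first
        .iter()
        .filter(|ordering| rest.iter().all(|child| child.contains(ordering)))
        .cloned()
        .collect()
}

fn main() {
    let child1 = vec![vec!["a", "b"], vec!["c"]];
    let child2 = vec![vec!["a", "b"]];
    // Only the ordering common to every input survives the union.
    assert_eq!(union_orderings(&[child1, child2]), vec![vec!["a", "b"]]);
}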
let output_partitioning = input.output_partitioning().clone(); // Determine execution mode: diff --git a/datafusion/physical-plan/src/work_table.rs b/datafusion/physical-plan/src/work_table.rs index 5f3cf6e2aee8..ba95640a87c7 100644 --- a/datafusion/physical-plan/src/work_table.rs +++ b/datafusion/physical-plan/src/work_table.rs @@ -225,7 +225,7 @@ mod tests { #[test] fn test_work_table() { let work_table = WorkTable::new(); - // cann't take from empty work_table + // can't take from empty work_table assert!(work_table.take().is_err()); let pool = Arc::new(UnboundedMemoryPool::default()) as _; diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index 23228dd54591..1dd413b0d1a0 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -130,6 +130,12 @@ message Decimal{ int32 scale = 4; } +message Decimal256Type{ + reserved 1, 2; + uint32 precision = 3; + int32 scale = 4; +} + message List{ Field field_type = 1; } @@ -335,6 +341,7 @@ message ArrowType{ TimeUnit TIME64 = 22 ; IntervalUnit INTERVAL = 23 ; Decimal DECIMAL = 24 ; + Decimal256Type DECIMAL256 = 36; List LIST = 25; List LARGE_LIST = 26; FixedSizeList FIXED_SIZE_LIST = 27; @@ -410,6 +417,7 @@ message CsvOptions { string null_value = 12; // Optional representation of null value bytes comment = 13; // Optional comment character as a byte bytes double_quote = 14; // Indicates if quotes are doubled + bytes newlines_in_values = 15; // Indicates if newlines are supported in values } // Options controlling CSV format diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index c84d8eea2a48..21db66a12701 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -261,6 +261,10 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType { precision, scale, }) => DataType::Decimal128(*precision as u8, *scale as i8), + arrow_type::ArrowTypeEnum::Decimal256(protobuf::Decimal256Type { + precision, + scale, + }) => DataType::Decimal256(*precision as u8, *scale as i8), arrow_type::ArrowTypeEnum::List(list) => { let list_type = list.as_ref().field_type.as_deref().required("field_type")?; @@ -861,6 +865,7 @@ impl TryFrom<&protobuf::CsvOptions> for CsvOptions { quote: proto_opts.quote[0], escape: proto_opts.escape.first().copied(), double_quote: proto_opts.has_header.first().map(|h| *h != 0), + newlines_in_values: proto_opts.newlines_in_values.first().map(|h| *h != 0), compression: proto_opts.compression().into(), schema_infer_max_rec: proto_opts.schema_infer_max_rec as usize, date_format: (!proto_opts.date_format.is_empty()) diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index be3cc58b23df..511072f3cb55 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -175,6 +175,9 @@ impl serde::Serialize for ArrowType { arrow_type::ArrowTypeEnum::Decimal(v) => { struct_ser.serialize_field("DECIMAL", v)?; } + arrow_type::ArrowTypeEnum::Decimal256(v) => { + struct_ser.serialize_field("DECIMAL256", v)?; + } arrow_type::ArrowTypeEnum::List(v) => { struct_ser.serialize_field("LIST", v)?; } @@ -241,6 +244,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType { "TIME64", "INTERVAL", "DECIMAL", + "DECIMAL256", "LIST", "LARGE_LIST", "LARGELIST", @@ -282,6 +286,7 @@ impl<'de> 
serde::Deserialize<'de> for ArrowType { Time64, Interval, Decimal, + Decimal256, List, LargeList, FixedSizeList, @@ -338,6 +343,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType { "TIME64" => Ok(GeneratedField::Time64), "INTERVAL" => Ok(GeneratedField::Interval), "DECIMAL" => Ok(GeneratedField::Decimal), + "DECIMAL256" => Ok(GeneratedField::Decimal256), "LIST" => Ok(GeneratedField::List), "LARGELIST" | "LARGE_LIST" => Ok(GeneratedField::LargeList), "FIXEDSIZELIST" | "FIXED_SIZE_LIST" => Ok(GeneratedField::FixedSizeList), @@ -556,6 +562,13 @@ impl<'de> serde::Deserialize<'de> for ArrowType { return Err(serde::de::Error::duplicate_field("DECIMAL")); } arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal) +; + } + GeneratedField::Decimal256 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("DECIMAL256")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal256) ; } GeneratedField::List => { @@ -1884,6 +1897,9 @@ impl serde::Serialize for CsvOptions { if !self.double_quote.is_empty() { len += 1; } + if !self.newlines_in_values.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion_common.CsvOptions", len)?; if !self.has_header.is_empty() { #[allow(clippy::needless_borrow)] @@ -1936,6 +1952,10 @@ impl serde::Serialize for CsvOptions { #[allow(clippy::needless_borrow)] struct_ser.serialize_field("doubleQuote", pbjson::private::base64::encode(&self.double_quote).as_str())?; } + if !self.newlines_in_values.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("newlinesInValues", pbjson::private::base64::encode(&self.newlines_in_values).as_str())?; + } struct_ser.end() } } @@ -1969,6 +1989,8 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { "comment", "double_quote", "doubleQuote", + "newlines_in_values", + "newlinesInValues", ]; #[allow(clippy::enum_variant_names)] @@ -1987,6 +2009,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { NullValue, Comment, DoubleQuote, + NewlinesInValues, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -2022,6 +2045,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { "nullValue" | "null_value" => Ok(GeneratedField::NullValue), "comment" => Ok(GeneratedField::Comment), "doubleQuote" | "double_quote" => Ok(GeneratedField::DoubleQuote), + "newlinesInValues" | "newlines_in_values" => Ok(GeneratedField::NewlinesInValues), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -2055,6 +2079,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { let mut null_value__ = None; let mut comment__ = None; let mut double_quote__ = None; + let mut newlines_in_values__ = None; while let Some(k) = map_.next_key()? 
{ match k { GeneratedField::HasHeader => { @@ -2155,6 +2180,14 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } + GeneratedField::NewlinesInValues => { + if newlines_in_values__.is_some() { + return Err(serde::de::Error::duplicate_field("newlinesInValues")); + } + newlines_in_values__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } } } Ok(CsvOptions { @@ -2172,6 +2205,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { null_value: null_value__.unwrap_or_default(), comment: comment__.unwrap_or_default(), double_quote: double_quote__.unwrap_or_default(), + newlines_in_values: newlines_in_values__.unwrap_or_default(), }) } } @@ -2828,6 +2862,118 @@ impl<'de> serde::Deserialize<'de> for Decimal256 { deserializer.deserialize_struct("datafusion_common.Decimal256", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for Decimal256Type { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.precision != 0 { + len += 1; + } + if self.scale != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal256Type", len)?; + if self.precision != 0 { + struct_ser.serialize_field("precision", &self.precision)?; + } + if self.scale != 0 { + struct_ser.serialize_field("scale", &self.scale)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Decimal256Type { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "precision", + "scale", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Precision, + Scale, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "precision" => Ok(GeneratedField::Precision), + "scale" => Ok(GeneratedField::Scale), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Decimal256Type; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Decimal256Type") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut precision__ = None; + let mut scale__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Precision => { + if precision__.is_some() { + return Err(serde::de::Error::duplicate_field("precision")); + } + precision__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Scale => { + if scale__.is_some() { + return Err(serde::de::Error::duplicate_field("scale")); + } + scale__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(Decimal256Type { + precision: precision__.unwrap_or_default(), + scale: scale__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.Decimal256Type", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for DfField { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index b0674ff28d75..62919e218b13 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -140,6 +140,14 @@ pub struct Decimal { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct Decimal256Type { + #[prost(uint32, tag = "3")] + pub precision: u32, + #[prost(int32, tag = "4")] + pub scale: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct List { #[prost(message, optional, boxed, tag = "1")] pub field_type: ::core::option::Option<::prost::alloc::boxed::Box>, @@ -446,7 +454,7 @@ pub struct Decimal256 { pub struct ArrowType { #[prost( oneof = "arrow_type::ArrowTypeEnum", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 36, 25, 26, 27, 28, 29, 30, 33" )] pub arrow_type_enum: ::core::option::Option, } @@ -516,6 +524,8 @@ pub mod arrow_type { Interval(i32), #[prost(message, tag = "24")] Decimal(super::Decimal), + #[prost(message, tag = "36")] + Decimal256(super::Decimal256Type), #[prost(message, tag = "25")] List(::prost::alloc::boxed::Box), #[prost(message, tag = "26")] @@ -633,6 +643,9 @@ pub struct CsvOptions { /// Indicates if quotes are doubled #[prost(bytes = "vec", tag = "14")] pub double_quote: ::prost::alloc::vec::Vec, + /// Indicates if newlines are supported in values + #[prost(bytes = "vec", tag = "15")] + pub newlines_in_values: ::prost::alloc::vec::Vec, } /// Options controlling CSV format #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index 04ec1d182875..24083e8b7276 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -192,9 +192,10 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { precision: *precision as u32, scale: *scale as i32, }), - DataType::Decimal256(_, _) => { - return Err(Error::General("Proto serialization error: The Decimal256 data type is not yet supported".to_owned())) - } + DataType::Decimal256(precision, scale) => Self::Decimal256(protobuf::Decimal256Type { + precision: *precision as u32, + scale: *scale as i32, + }), DataType::Map(field, sorted) => { Self::Map(Box::new( protobuf::Map { @@ -903,6 +904,9 @@ impl TryFrom<&CsvOptions> for protobuf::CsvOptions { quote: vec![opts.quote], escape: opts.escape.map_or_else(Vec::new, |e| 
vec![e]), double_quote: opts.double_quote.map_or_else(Vec::new, |h| vec![h as u8]), + newlines_in_values: opts + .newlines_in_values + .map_or_else(Vec::new, |h| vec![h as u8]), compression: compression.into(), schema_infer_max_rec: opts.schema_infer_max_rec as u64, date_format: opts.date_format.clone().unwrap_or_default(), diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index dc551778c5fb..e133abd46f43 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -472,7 +472,7 @@ enum AggregateFunction { // AVG = 3; // COUNT = 4; // APPROX_DISTINCT = 5; - ARRAY_AGG = 6; + // ARRAY_AGG = 6; // VARIANCE = 7; // VARIANCE_POP = 8; // COVARIANCE = 9; @@ -1007,6 +1007,7 @@ message CsvScanExecNode { oneof optional_comment { string comment = 6; } + bool newlines_in_values = 7; } message AvroScanExecNode { diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index 56666306548f..a50103f1b8b0 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -140,6 +140,14 @@ pub struct Decimal { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct Decimal256Type { + #[prost(uint32, tag = "3")] + pub precision: u32, + #[prost(int32, tag = "4")] + pub scale: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct List { #[prost(message, optional, boxed, tag = "1")] pub field_type: ::core::option::Option<::prost::alloc::boxed::Box>, @@ -446,7 +454,7 @@ pub struct Decimal256 { pub struct ArrowType { #[prost( oneof = "arrow_type::ArrowTypeEnum", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 36, 25, 26, 27, 28, 29, 30, 33" )] pub arrow_type_enum: ::core::option::Option, } @@ -516,6 +524,8 @@ pub mod arrow_type { Interval(i32), #[prost(message, tag = "24")] Decimal(super::Decimal), + #[prost(message, tag = "36")] + Decimal256(super::Decimal256Type), #[prost(message, tag = "25")] List(::prost::alloc::boxed::Box), #[prost(message, tag = "26")] @@ -633,6 +643,9 @@ pub struct CsvOptions { /// Indicates if quotes are doubled #[prost(bytes = "vec", tag = "14")] pub double_quote: ::prost::alloc::vec::Vec, + /// Indicates if newlines are supported in values + #[prost(bytes = "vec", tag = "15")] + pub newlines_in_values: ::prost::alloc::vec::Vec, } /// Options controlling CSV format #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 8f77c24bd911..c5ec67d72875 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -534,7 +534,6 @@ impl serde::Serialize for AggregateFunction { let variant = match self { Self::Min => "MIN", Self::Max => "MAX", - Self::ArrayAgg => "ARRAY_AGG", }; serializer.serialize_str(variant) } @@ -548,7 +547,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { const FIELDS: &[&str] = &[ "MIN", "MAX", - "ARRAY_AGG", ]; struct GeneratedVisitor; @@ -591,7 +589,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { match value { "MIN" => Ok(AggregateFunction::Min), "MAX" => 
Ok(AggregateFunction::Max), - "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } @@ -3605,6 +3602,9 @@ impl serde::Serialize for CsvScanExecNode { if !self.quote.is_empty() { len += 1; } + if self.newlines_in_values { + len += 1; + } if self.optional_escape.is_some() { len += 1; } @@ -3624,6 +3624,9 @@ impl serde::Serialize for CsvScanExecNode { if !self.quote.is_empty() { struct_ser.serialize_field("quote", &self.quote)?; } + if self.newlines_in_values { + struct_ser.serialize_field("newlinesInValues", &self.newlines_in_values)?; + } if let Some(v) = self.optional_escape.as_ref() { match v { csv_scan_exec_node::OptionalEscape::Escape(v) => { @@ -3654,6 +3657,8 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { "hasHeader", "delimiter", "quote", + "newlines_in_values", + "newlinesInValues", "escape", "comment", ]; @@ -3664,6 +3669,7 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { HasHeader, Delimiter, Quote, + NewlinesInValues, Escape, Comment, } @@ -3691,6 +3697,7 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), "delimiter" => Ok(GeneratedField::Delimiter), "quote" => Ok(GeneratedField::Quote), + "newlinesInValues" | "newlines_in_values" => Ok(GeneratedField::NewlinesInValues), "escape" => Ok(GeneratedField::Escape), "comment" => Ok(GeneratedField::Comment), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -3716,6 +3723,7 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { let mut has_header__ = None; let mut delimiter__ = None; let mut quote__ = None; + let mut newlines_in_values__ = None; let mut optional_escape__ = None; let mut optional_comment__ = None; while let Some(k) = map_.next_key()? 
{ @@ -3744,6 +3752,12 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { } quote__ = Some(map_.next_value()?); } + GeneratedField::NewlinesInValues => { + if newlines_in_values__.is_some() { + return Err(serde::de::Error::duplicate_field("newlinesInValues")); + } + newlines_in_values__ = Some(map_.next_value()?); + } GeneratedField::Escape => { if optional_escape__.is_some() { return Err(serde::de::Error::duplicate_field("escape")); @@ -3763,6 +3777,7 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { has_header: has_header__.unwrap_or_default(), delimiter: delimiter__.unwrap_or_default(), quote: quote__.unwrap_or_default(), + newlines_in_values: newlines_in_values__.unwrap_or_default(), optional_escape: optional_escape__, optional_comment: optional_comment__, }) diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 605c56fa946a..98b70dc25351 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1542,6 +1542,8 @@ pub struct CsvScanExecNode { pub delimiter: ::prost::alloc::string::String, #[prost(string, tag = "4")] pub quote: ::prost::alloc::string::String, + #[prost(bool, tag = "7")] + pub newlines_in_values: bool, #[prost(oneof = "csv_scan_exec_node::OptionalEscape", tags = "5")] pub optional_escape: ::core::option::Option, #[prost(oneof = "csv_scan_exec_node::OptionalComment", tags = "6")] @@ -1939,12 +1941,11 @@ pub struct PartitionStats { #[repr(i32)] pub enum AggregateFunction { Min = 0, - Max = 1, /// SUM = 2; /// AVG = 3; /// COUNT = 4; /// APPROX_DISTINCT = 5; - /// + /// ARRAY_AGG = 6; /// VARIANCE = 7; /// VARIANCE_POP = 8; /// COVARIANCE = 9; @@ -1973,7 +1974,7 @@ pub enum AggregateFunction { /// REGR_SXY = 34; /// STRING_AGG = 35; /// NTH_VALUE_AGG = 36; - ArrayAgg = 6, + Max = 1, } impl AggregateFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -1984,7 +1985,6 @@ impl AggregateFunction { match self { AggregateFunction::Min => "MIN", AggregateFunction::Max => "MAX", - AggregateFunction::ArrayAgg => "ARRAY_AGG", } } /// Creates an enum from field names used in the ProtoBuf definition. 
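Note (illustrative, not part of the patch): the `newlines_in_values` field threaded through `CsvScanExecNode` and the serialized `CsvOptions` message above corresponds to an `Option<bool>` on the session-level `CsvOptions`. A minimal sketch of setting it, assuming `CsvOptions` (imported from `datafusion::config` later in this diff) implements `Default` and that `CsvFormatFactory`'s `options` field is public, as the codec changes below rely on:

```rust
// Illustrative sketch only; not part of this patch.
use datafusion::config::CsvOptions;
use datafusion::datasource::file_format::csv::CsvFormatFactory;

fn main() {
    // Allow quoted CSV fields to contain embedded newlines.
    let options = CsvOptions {
        newlines_in_values: Some(true),
        ..Default::default()
    };

    // Mirrors how try_decode_file_format below rebuilds the factory
    // from the decoded options.
    let _factory = CsvFormatFactory {
        options: Some(options),
    };
}
```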
@@ -1992,7 +1992,6 @@ impl AggregateFunction { match value { "MIN" => Some(Self::Min), "MAX" => Some(Self::Max), - "ARRAY_AGG" => Some(Self::ArrayAgg), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/file_formats.rs b/datafusion/proto/src/logical_plan/file_formats.rs index 09e36a650b9f..2c4085b88869 100644 --- a/datafusion/proto/src/logical_plan/file_formats.rs +++ b/datafusion/proto/src/logical_plan/file_formats.rs @@ -18,19 +18,129 @@ use std::sync::Arc; use datafusion::{ + config::CsvOptions, datasource::file_format::{ arrow::ArrowFormatFactory, csv::CsvFormatFactory, json::JsonFormatFactory, parquet::ParquetFormatFactory, FileFormatFactory, }, prelude::SessionContext, }; -use datafusion_common::{not_impl_err, TableReference}; +use datafusion_common::{ + exec_err, not_impl_err, parsers::CompressionTypeVariant, DataFusionError, + TableReference, +}; +use prost::Message; + +use crate::protobuf::CsvOptions as CsvOptionsProto; use super::LogicalExtensionCodec; #[derive(Debug)] pub struct CsvLogicalExtensionCodec; +impl CsvOptionsProto { + fn from_factory(factory: &CsvFormatFactory) -> Self { + if let Some(options) = &factory.options { + CsvOptionsProto { + has_header: options.has_header.map_or(vec![], |v| vec![v as u8]), + delimiter: vec![options.delimiter], + quote: vec![options.quote], + escape: options.escape.map_or(vec![], |v| vec![v]), + double_quote: options.double_quote.map_or(vec![], |v| vec![v as u8]), + compression: options.compression as i32, + schema_infer_max_rec: options.schema_infer_max_rec as u64, + date_format: options.date_format.clone().unwrap_or_default(), + datetime_format: options.datetime_format.clone().unwrap_or_default(), + timestamp_format: options.timestamp_format.clone().unwrap_or_default(), + timestamp_tz_format: options + .timestamp_tz_format + .clone() + .unwrap_or_default(), + time_format: options.time_format.clone().unwrap_or_default(), + null_value: options.null_value.clone().unwrap_or_default(), + comment: options.comment.map_or(vec![], |v| vec![v]), + newlines_in_values: options + .newlines_in_values + .map_or(vec![], |v| vec![v as u8]), + } + } else { + CsvOptionsProto::default() + } + } +} + +impl From<&CsvOptionsProto> for CsvOptions { + fn from(proto: &CsvOptionsProto) -> Self { + CsvOptions { + has_header: if !proto.has_header.is_empty() { + Some(proto.has_header[0] != 0) + } else { + None + }, + delimiter: proto.delimiter.first().copied().unwrap_or(b','), + quote: proto.quote.first().copied().unwrap_or(b'"'), + escape: if !proto.escape.is_empty() { + Some(proto.escape[0]) + } else { + None + }, + double_quote: if !proto.double_quote.is_empty() { + Some(proto.double_quote[0] != 0) + } else { + None + }, + compression: match proto.compression { + 0 => CompressionTypeVariant::GZIP, + 1 => CompressionTypeVariant::BZIP2, + 2 => CompressionTypeVariant::XZ, + 3 => CompressionTypeVariant::ZSTD, + _ => CompressionTypeVariant::UNCOMPRESSED, + }, + schema_infer_max_rec: proto.schema_infer_max_rec as usize, + date_format: if proto.date_format.is_empty() { + None + } else { + Some(proto.date_format.clone()) + }, + datetime_format: if proto.datetime_format.is_empty() { + None + } else { + Some(proto.datetime_format.clone()) + }, + timestamp_format: if proto.timestamp_format.is_empty() { + None + } else { + Some(proto.timestamp_format.clone()) + }, + timestamp_tz_format: if proto.timestamp_tz_format.is_empty() { + None + } else { + Some(proto.timestamp_tz_format.clone()) + }, + time_format: if proto.time_format.is_empty() { + None + } else { + 
Some(proto.time_format.clone()) + }, + null_value: if proto.null_value.is_empty() { + None + } else { + Some(proto.null_value.clone()) + }, + comment: if !proto.comment.is_empty() { + Some(proto.comment[0]) + } else { + None + }, + newlines_in_values: if proto.newlines_in_values.is_empty() { + None + } else { + Some(proto.newlines_in_values[0] != 0) + }, + } + } +} + // TODO! This is a placeholder for now and needs to be implemented for real. impl LogicalExtensionCodec for CsvLogicalExtensionCodec { fn try_decode( @@ -73,17 +183,41 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec { fn try_decode_file_format( &self, - __buf: &[u8], - __ctx: &SessionContext, + buf: &[u8], + _ctx: &SessionContext, ) -> datafusion_common::Result> { - Ok(Arc::new(CsvFormatFactory::new())) + let proto = CsvOptionsProto::decode(buf).map_err(|e| { + DataFusionError::Execution(format!( + "Failed to decode CsvOptionsProto: {:?}", + e + )) + })?; + let options: CsvOptions = (&proto).into(); + Ok(Arc::new(CsvFormatFactory { + options: Some(options), + })) } fn try_encode_file_format( &self, - __buf: &[u8], - __node: Arc, + buf: &mut Vec, + node: Arc, ) -> datafusion_common::Result<()> { + let options = + if let Some(csv_factory) = node.as_any().downcast_ref::() { + csv_factory.options.clone().unwrap_or_default() + } else { + return exec_err!("{}", "Unsupported FileFormatFactory type".to_string()); + }; + + let proto = CsvOptionsProto::from_factory(&CsvFormatFactory { + options: Some(options), + }); + + proto.encode(buf).map_err(|e| { + DataFusionError::Execution(format!("Failed to encode CsvOptions: {:?}", e)) + })?; + Ok(()) } } @@ -141,7 +275,7 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec { fn try_encode_file_format( &self, - __buf: &[u8], + __buf: &mut Vec, __node: Arc, ) -> datafusion_common::Result<()> { Ok(()) @@ -201,7 +335,7 @@ impl LogicalExtensionCodec for ParquetLogicalExtensionCodec { fn try_encode_file_format( &self, - __buf: &[u8], + __buf: &mut Vec, __node: Arc, ) -> datafusion_common::Result<()> { Ok(()) @@ -261,7 +395,7 @@ impl LogicalExtensionCodec for ArrowLogicalExtensionCodec { fn try_encode_file_format( &self, - __buf: &[u8], + __buf: &mut Vec, __node: Arc, ) -> datafusion_common::Result<()> { Ok(()) @@ -321,7 +455,7 @@ impl LogicalExtensionCodec for AvroLogicalExtensionCodec { fn try_encode_file_format( &self, - __buf: &[u8], + __buf: &mut Vec, __node: Arc, ) -> datafusion_common::Result<()> { Ok(()) diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index b6b556a8ed6b..aea8e454a31c 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -144,7 +144,6 @@ impl From for AggregateFunction { match agg_fun { protobuf::AggregateFunction::Min => Self::Min, protobuf::AggregateFunction::Max => Self::Max, - protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, } } } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 2a963fb13ccf..5427f34e8e07 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -131,7 +131,7 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync { fn try_encode_file_format( &self, - _buf: &[u8], + _buf: &mut Vec, _node: Arc, ) -> Result<()> { Ok(()) @@ -1666,10 +1666,9 @@ impl AsLogicalPlan for LogicalPlanNode { input, extension_codec, )?; - - let buf = Vec::new(); + let mut buf = Vec::new(); extension_codec - .try_encode_file_format(&buf, 
file_type_to_format(file_type)?)?; + .try_encode_file_format(&mut buf, file_type_to_format(file_type)?)?; Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CopyTo(Box::new( diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 9607b918eb89..c2441892e8a8 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -116,7 +116,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { match value { AggregateFunction::Min => Self::Min, AggregateFunction::Max => Self::Max, - AggregateFunction::ArrayAgg => Self::ArrayAgg, } } } @@ -386,7 +385,6 @@ pub fn serialize_expr( }) => match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { let aggr_function = match fun { - AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, AggregateFunction::Min => protobuf::AggregateFunction::Min, AggregateFunction::Max => protobuf::AggregateFunction::Max, }; diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 1220f42ded83..8c9e5bbd0e95 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -211,6 +211,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { } else { None }, + scan.newlines_in_values, FileCompressionType::UNCOMPRESSED, ))), #[cfg(feature = "parquet")] @@ -506,7 +507,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { // TODO: `order by` is not supported for UDAF yet let sort_exprs = &[]; let ordering_req = &[]; - udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs, sort_exprs, ordering_req, &physical_schema, name, agg_node.ignore_nulls, agg_node.distinct) + udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs, sort_exprs, ordering_req, &physical_schema, name, agg_node.ignore_nulls, agg_node.distinct, false) } } }).transpose()?.ok_or_else(|| { @@ -1579,6 +1580,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { } else { None }, + newlines_in_values: exec.newlines_in_values(), }, )), }); diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 7ea2902cf3c0..140482b9903c 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,10 +23,9 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - ArrayAgg, BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, DistinctArrayAgg, - InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, - NthValue, Ntile, OrderSensitiveArrayAgg, Rank, RankType, RowNumber, TryCastExpr, - WindowShift, + BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, InListExpr, IsNotNullExpr, + IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, Ntile, Rank, + RankType, RowNumber, TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -260,16 +259,9 @@ struct AggrFn { fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { let aggr_expr = expr.as_any(); - let mut distinct = false; - - let inner = if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::ArrayAgg 
- } else if aggr_expr.downcast_ref::().is_some() { - distinct = true; - protobuf::AggregateFunction::ArrayAgg - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::ArrayAgg - } else if aggr_expr.downcast_ref::().is_some() { + + // TODO: remove Min and Max + let inner = if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Min } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Max @@ -277,7 +269,10 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { return not_impl_err!("Aggregate function not supported: {expr:?}"); }; - Ok(AggrFn { inner, distinct }) + Ok(AggrFn { + inner, + distinct: false, + }) } pub fn serialize_physical_sort_exprs( diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 0117502f400d..e17515086ecd 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -15,25 +15,25 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::collections::HashMap; -use std::fmt::{self, Debug, Formatter}; -use std::sync::Arc; -use std::vec; - use arrow::array::{ ArrayRef, FixedSizeListArray, Int32Builder, MapArray, MapBuilder, StringBuilder, }; use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, + DECIMAL256_MAX_PRECISION, }; use prost::Message; +use std::any::Any; +use std::collections::HashMap; +use std::fmt::{self, Debug, Formatter}; +use std::sync::Arc; +use std::vec; use datafusion::datasource::file_format::arrow::ArrowFormatFactory; use datafusion::datasource::file_format::csv::CsvFormatFactory; -use datafusion::datasource::file_format::format_as_file_type; use datafusion::datasource::file_format::parquet::ParquetFormatFactory; +use datafusion::datasource::file_format::{format_as_file_type, DefaultFileType}; use datafusion::datasource::provider::TableProviderFactory; use datafusion::datasource::TableProvider; use datafusion::execution::session_state::SessionStateBuilder; @@ -44,6 +44,7 @@ use datafusion::functions_aggregate::expr_fn::{ count_distinct, covar_pop, covar_samp, first_value, grouping, median, stddev, stddev_pop, sum, var_pop, var_sample, }; +use datafusion::functions_array::map::map; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::TableOptions; @@ -66,7 +67,7 @@ use datafusion_expr::{ }; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::expr_fn::{ - avg, bit_and, bit_or, bit_xor, bool_and, bool_or, corr, + array_agg, avg, bit_and, bit_or, bit_xor, bool_and, bool_or, corr, }; use datafusion_functions_aggregate::string_agg::string_agg; use datafusion_proto::bytes::{ @@ -378,7 +379,9 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { parquet_format.global.dictionary_page_size_limit = 444; parquet_format.global.max_row_group_size = 555; - let file_type = format_as_file_type(Arc::new(ParquetFormatFactory::new())); + let file_type = format_as_file_type(Arc::new( + ParquetFormatFactory::new_with_options(parquet_format), + )); let plan = LogicalPlan::Copy(CopyTo { input: Arc::new(input), @@ -393,7 +396,6 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { let logical_round_trip = 
logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); - match logical_round_trip { LogicalPlan::Copy(copy_to) => { assert_eq!("test.parquet", copy_to.output_url); @@ -456,7 +458,9 @@ async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { csv_format.time_format = Some("HH:mm:ss".to_string()); csv_format.null_value = Some("NIL".to_string()); - let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new())); + let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new_with_options( + csv_format.clone(), + ))); let plan = LogicalPlan::Copy(CopyTo { input: Arc::new(input), @@ -477,6 +481,27 @@ async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { assert_eq!("test.csv", copy_to.output_url); assert_eq!("csv".to_string(), copy_to.file_type.get_ext()); assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); + + let file_type = copy_to + .file_type + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + + let format_factory = file_type.as_format_factory(); + let csv_factory = format_factory + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + let csv_config = csv_factory.options.as_ref().unwrap(); + assert_eq!(csv_format.delimiter, csv_config.delimiter); + assert_eq!(csv_format.date_format, csv_config.date_format); + assert_eq!(csv_format.datetime_format, csv_config.datetime_format); + assert_eq!(csv_format.timestamp_format, csv_config.timestamp_format); + assert_eq!(csv_format.time_format, csv_config.time_format); + assert_eq!(csv_format.null_value, csv_config.null_value) } _ => panic!(), } @@ -702,6 +727,12 @@ async fn roundtrip_expr_api() -> Result<()> { string_agg(col("a").cast_to(&DataType::Utf8, &schema)?, lit("|")), bool_and(lit(true)), bool_or(lit(true)), + array_agg(lit(1)), + array_agg(lit(1)).distinct().build().unwrap(), + map( + vec![lit(1), lit(2), lit(3)], + vec![lit(10), lit(20), lit(30)], + ), ]; // ensure expressions created with the expr api can be round tripped @@ -1372,6 +1403,7 @@ fn round_trip_datatype() { DataType::Utf8, DataType::LargeUtf8, DataType::Decimal128(7, 12), + DataType::Decimal256(DECIMAL256_MAX_PRECISION, 0), // Recursive list tests DataType::List(new_arc_field("Level1", DataType::Binary, true)), DataType::List(new_arc_field( diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index fba6dfe42599..31ed0837d2f5 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -301,6 +301,7 @@ fn roundtrip_window() -> Result<()> { "avg(b)", false, false, + false, )?, &[], &[], @@ -324,6 +325,7 @@ fn roundtrip_window() -> Result<()> { "SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING", false, false, + false, )?; let sliding_aggr_window_expr = Arc::new(SlidingAggregateWindowExpr::new( @@ -367,6 +369,7 @@ fn rountrip_aggregate() -> Result<()> { "AVG(b)", false, false, + false, )?], // NTH_VALUE vec![create_aggregate_expr( @@ -379,6 +382,7 @@ fn rountrip_aggregate() -> Result<()> { "NTH_VALUE(b, 1)", false, false, + false, )?], // STRING_AGG vec![create_aggregate_expr( @@ -394,6 +398,7 @@ fn rountrip_aggregate() -> Result<()> { "STRING_AGG(name, ',')", false, false, + false, )?], ]; @@ -431,6 +436,7 @@ fn rountrip_aggregate_with_limit() -> Result<()> { "AVG(b)", false, false, + false, )?]; let agg = AggregateExec::try_new( @@ -502,6 +508,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { "example_agg", 
false, false, + false, )?]; roundtrip_test_with_context( @@ -1000,6 +1007,7 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { "aggregate_udf", false, false, + false, )?; let filter = Arc::new(FilterExec::try_new( @@ -1032,6 +1040,7 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { "aggregate_udf", true, true, + false, )?; let aggregate = Arc::new(AggregateExec::try_new( diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index b724afabaf09..d9ee1b4db8e2 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -15,13 +15,18 @@ // specific language governing permissions and limitations // under the License. +use std::{collections::HashMap, sync::Arc}; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::config::ConfigOptions; use datafusion_common::{plan_err, Result}; +use datafusion_expr::planner::ExprPlanner; use datafusion_expr::WindowUDF; use datafusion_expr::{ logical_plan::builder::LogicalTableSource, AggregateUDF, ScalarUDF, TableSource, }; +use datafusion_functions::core::planner::CoreFunctionPlanner; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_sql::{ @@ -29,7 +34,6 @@ use datafusion_sql::{ sqlparser::{dialect::GenericDialect, parser::Parser}, TableReference, }; -use std::{collections::HashMap, sync::Arc}; fn main() { let sql = "SELECT \ @@ -53,7 +57,8 @@ fn main() { // create a logical query plan let context_provider = MyContextProvider::new() .with_udaf(sum_udaf()) - .with_udaf(count_udaf()); + .with_udaf(count_udaf()) + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); let sql_to_rel = SqlToRel::new(&context_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); @@ -65,6 +70,7 @@ struct MyContextProvider { options: ConfigOptions, tables: HashMap>, udafs: HashMap>, + expr_planners: Vec>, } impl MyContextProvider { @@ -73,6 +79,11 @@ impl MyContextProvider { self } + fn with_expr_planner(mut self, planner: Arc) -> Self { + self.expr_planners.push(planner); + self + } + fn new() -> Self { let mut tables = HashMap::new(); tables.insert( @@ -105,6 +116,7 @@ impl MyContextProvider { tables, options: Default::default(), udafs: Default::default(), + expr_planners: vec![], } } } @@ -154,4 +166,8 @@ impl ContextProvider for MyContextProvider { fn udwf_names(&self) -> Vec { Vec::new() } + + fn get_expr_planners(&self) -> &[Arc] { + &self.expr_planners + } } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index dab328cc4908..4804752d8389 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -21,6 +21,7 @@ use datafusion_common::{ internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Result, }; +use datafusion_expr::planner::PlannerResult; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ expr, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition, @@ -66,7 +67,7 @@ pub fn suggest_valid_function( find_closest_match(valid_funcs, input_function_name) } -/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve) +/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitive) /// Input `candidates` must not be empty otherwise it will panic fn find_closest_match(candidates: 
Vec, target: &str) -> String { let target = target.to_lowercase(); @@ -227,6 +228,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { crate::utils::normalize_ident(name.0[0].clone()) }; + if name.eq("make_map") { + let mut fn_args = + self.function_args_to_expr(args.clone(), schema, planner_context)?; + for planner in self.context_provider.get_expr_planners().iter() { + match planner.plan_make_map(fn_args)? { + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(args) => fn_args = args, + } + } + } + // user-defined function (UDF) should have precedence if let Some(fm) = self.context_provider.get_function_meta(&name) { let args = self.function_args_to_expr(args, schema, planner_context)?; diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 39736b1fbba5..f8979bde3086 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -15,14 +15,17 @@ // specific language governing permissions and limitations // under the License. -use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow_schema::Field; +use sqlparser::ast::{Expr as SQLExpr, Ident}; + use datafusion_common::{ internal_err, not_impl_err, plan_datafusion_err, Column, DFSchema, DataFusionError, - Result, ScalarValue, TableReference, + Result, TableReference, }; -use datafusion_expr::{expr::ScalarFunction, lit, Case, Expr}; -use sqlparser::ast::{Expr as SQLExpr, Ident}; +use datafusion_expr::planner::PlannerResult; +use datafusion_expr::{Case, Expr}; + +use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn sql_identifier_to_expr( @@ -125,26 +128,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match search_result { // found matching field with spare identifier(s) for nested field(s) in structure Some((field, qualifier, nested_names)) if !nested_names.is_empty() => { - // TODO: remove when can support multiple nested identifiers - if nested_names.len() > 1 { - return not_impl_err!( - "Nested identifiers not yet supported for column {}", - Column::from((qualifier, field)).quoted_flat_name() - ); - } - let nested_name = nested_names[0].to_string(); - - let col = Expr::Column(Column::from((qualifier, field))); - if let Some(udf) = - self.context_provider.get_function_meta("get_field") - { - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - udf, - vec![col, lit(ScalarValue::from(nested_name))], - ))) - } else { - internal_err!("get_field not found") + // found matching field with spare identifier(s) for nested field(s) in structure + for planner in self.context_provider.get_expr_planners() { + if let Ok(planner_result) = planner.plan_compound_identifier( + field, + qualifier, + nested_names, + ) { + match planner_result { + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(_args) => {} + } + } } + not_impl_err!( + "Compound identifiers not supported by ExprPlanner: {ids:?}" + ) } // found matching field with no spare identifier(s) Some((field, qualifier, _nested_names)) => { diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index bc13484235c3..a743aa72829d 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -218,7 +218,7 @@ impl fmt::Display for CreateExternalTable { /// /// This can either be a [`Statement`] from [`sqlparser`] from a /// standard SQL dialect, or a DataFusion extension such as `CREATE -/// EXTERAL TABLE`. See [`DFParser`] for more information. 
+/// EXTERNAL TABLE`. See [`DFParser`] for more information. /// /// [`Statement`]: sqlparser::ast::Statement #[derive(Debug, Clone, PartialEq, Eq)] @@ -1101,7 +1101,7 @@ mod tests { }); expect_parse_ok(sql, expected)?; - // positive case: column definiton allowed in 'partition by' clause + // positive case: column definition allowed in 'partition by' clause let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (p1 int) LOCATION 'foo.csv'"; let expected = Statement::CreateExternalTable(CreateExternalTable { diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 9380e569f2e4..b812dae5ae3e 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -105,7 +105,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Unnest table factor has empty input let schema = DFSchema::empty(); let input = LogicalPlanBuilder::empty(true).build()?; - // Unnest table factor can have multiple arugments. + // Unnest table factor can have multiple arguments. // We treat each argument as a separate unnest expression. let unnest_exprs = array_exprs .into_iter() diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 84b80c311245..fc46c3a841b5 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -306,7 +306,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut intermediate_select_exprs = select_exprs; // Each expr in select_exprs can contains multiple unnest stage // The transformation happen bottom up, one at a time for each iteration - // Ony exaust the loop if no more unnest transformation is found + // Only exaust the loop if no more unnest transformation is found for i in 0.. { let mut unnest_columns = vec![]; // from which column used for projection, before the unnest happen diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 6df25086305d..8eb4113f80a6 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -1146,6 +1146,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // we could introduce alias in OptionDefinition if this string matching thing grows format!("{base_query} WHERE name = 'datafusion.execution.time_zone'") } else { + // These values are what are used to make the information_schema table, so we just + // check here, before actually planning or executing the query, if it would produce no + // results, and error preemptively if it would (for a better UX) + let is_valid_variable = self + .context_provider + .options() + .entries() + .iter() + .any(|opt| opt.key == variable); + + if !is_valid_variable { + return plan_err!( + "'{variable}' is not a variable which can be viewed with 'SHOW'" + ); + } + format!("{base_query} WHERE name = '{variable}'") }; diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs index 06b4d4a710a3..02eb44dbb657 100644 --- a/datafusion/sql/src/unparser/ast.rs +++ b/datafusion/sql/src/unparser/ast.rs @@ -497,7 +497,7 @@ impl Default for DerivedRelationBuilder { pub(super) struct UninitializedFieldError(&'static str); impl UninitializedFieldError { - /// Create a new `UnitializedFieldError` for the specified field name. + /// Create a new `UninitializedFieldError` for the specified field name. 
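Note (illustrative, not part of the patch): the `statement.rs` hunk above makes `SHOW <variable>` fail during planning when the variable is not a known configuration entry, instead of silently returning an empty result. A rough sketch of the observable behaviour, assuming a Tokio runtime and a `SessionContext` with `information_schema` enabled (which `SHOW` already requires):

```rust
// Illustrative sketch only; not part of this patch.
use datafusion::error::Result;
use datafusion::prelude::{SessionConfig, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    let config = SessionConfig::new().with_information_schema(true);
    let ctx = SessionContext::new_with_config(config);

    // A real configuration key still plans and runs as before.
    ctx.sql("SHOW datafusion.execution.batch_size").await?;

    // An unknown key now fails during planning with a descriptive error.
    let err = ctx.sql("SHOW not.a.real.option").await.unwrap_err();
    assert!(err
        .to_string()
        .contains("is not a variable which can be viewed with 'SHOW'"));
    Ok(())
}
```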
pub fn new(field_name: &'static str) -> Self { UninitializedFieldError(field_name) } diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index eca2eb4fd0ec..7eca326386fc 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -15,8 +15,14 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + +use arrow_schema::TimeUnit; use regex::Regex; -use sqlparser::keywords::ALL_KEYWORDS; +use sqlparser::{ + ast::{self, Ident, ObjectName, TimezoneInfo}, + keywords::ALL_KEYWORDS, +}; /// `Dialect` to use for Unparsing /// @@ -27,7 +33,7 @@ use sqlparser::keywords::ALL_KEYWORDS; /// /// See /// See also the discussion in -pub trait Dialect { +pub trait Dialect: Send + Sync { /// Return the character used to quote identifiers. fn identifier_quote_style(&self, _identifier: &str) -> Option; @@ -36,8 +42,8 @@ pub trait Dialect { true } - // Does the dialect use TIMESTAMP to represent Date64 rather than DATETIME? - // E.g. Trino, Athena and Dremio does not have DATETIME data type + /// Does the dialect use TIMESTAMP to represent Date64 rather than DATETIME? + /// E.g. Trino, Athena and Dremio does not have DATETIME data type fn use_timestamp_for_date64(&self) -> bool { false } @@ -45,6 +51,51 @@ pub trait Dialect { fn interval_style(&self) -> IntervalStyle { IntervalStyle::PostgresVerbose } + + /// Does the dialect use DOUBLE PRECISION to represent Float64 rather than DOUBLE? + /// E.g. Postgres uses DOUBLE PRECISION instead of DOUBLE + fn float64_ast_dtype(&self) -> sqlparser::ast::DataType { + sqlparser::ast::DataType::Double + } + + /// The SQL type to use for Arrow Utf8 unparsing + /// Most dialects use VARCHAR, but some, like MySQL, require CHAR + fn utf8_cast_dtype(&self) -> ast::DataType { + ast::DataType::Varchar(None) + } + + /// The SQL type to use for Arrow LargeUtf8 unparsing + /// Most dialects use TEXT, but some, like MySQL, require CHAR + fn large_utf8_cast_dtype(&self) -> ast::DataType { + ast::DataType::Text + } + + /// The date field extract style to use: `DateFieldExtractStyle` + fn date_field_extract_style(&self) -> DateFieldExtractStyle { + DateFieldExtractStyle::DatePart + } + + /// The SQL type to use for Arrow Int64 unparsing + /// Most dialects use BigInt, but some, like MySQL, require SIGNED + fn int64_cast_dtype(&self) -> ast::DataType { + ast::DataType::BigInt(None) + } + + /// The SQL type to use for Timestamp unparsing + /// Most dialects use Timestamp, but some, like MySQL, require Datetime + /// Some dialects like Dremio does not support WithTimeZone and requires always Timestamp + fn timestamp_cast_dtype( + &self, + _time_unit: &TimeUnit, + tz: &Option>, + ) -> ast::DataType { + let tz_info = match tz { + Some(_) => TimezoneInfo::WithTimeZone, + None => TimezoneInfo::None, + }; + + ast::DataType::Timestamp(None, tz_info) + } } /// `IntervalStyle` to use for unparsing @@ -62,6 +113,19 @@ pub enum IntervalStyle { MySQL, } +/// Datetime subfield extraction style for unparsing +/// +/// `` +/// Different DBMSs follow different standards; popular ones are: +/// date_part('YEAR', date '2001-02-16') +/// EXTRACT(YEAR from date '2001-02-16') +/// Some DBMSs, like Postgres, support both, whereas others like MySQL require EXTRACT. 
+#[derive(Clone, Copy, PartialEq)] +pub enum DateFieldExtractStyle { + DatePart, + Extract, +} + pub struct DefaultDialect {} impl Dialect for DefaultDialect { @@ -87,6 +151,10 @@ impl Dialect for PostgreSqlDialect { fn interval_style(&self) -> IntervalStyle { IntervalStyle::PostgresVerbose } + + fn float64_ast_dtype(&self) -> sqlparser::ast::DataType { + sqlparser::ast::DataType::DoublePrecision + } } pub struct MySqlDialect {} @@ -103,6 +171,30 @@ impl Dialect for MySqlDialect { fn interval_style(&self) -> IntervalStyle { IntervalStyle::MySQL } + + fn utf8_cast_dtype(&self) -> ast::DataType { + ast::DataType::Char(None) + } + + fn large_utf8_cast_dtype(&self) -> ast::DataType { + ast::DataType::Char(None) + } + + fn date_field_extract_style(&self) -> DateFieldExtractStyle { + DateFieldExtractStyle::Extract + } + + fn int64_cast_dtype(&self) -> ast::DataType { + ast::DataType::Custom(ObjectName(vec![Ident::new("SIGNED")]), vec![]) + } + + fn timestamp_cast_dtype( + &self, + _time_unit: &TimeUnit, + _tz: &Option>, + ) -> ast::DataType { + ast::DataType::Datetime(None) + } } pub struct SqliteDialect {} @@ -118,6 +210,13 @@ pub struct CustomDialect { supports_nulls_first_in_sort: bool, use_timestamp_for_date64: bool, interval_style: IntervalStyle, + float64_ast_dtype: sqlparser::ast::DataType, + utf8_cast_dtype: ast::DataType, + large_utf8_cast_dtype: ast::DataType, + date_field_extract_style: DateFieldExtractStyle, + int64_cast_dtype: ast::DataType, + timestamp_cast_dtype: ast::DataType, + timestamp_tz_cast_dtype: ast::DataType, } impl Default for CustomDialect { @@ -127,6 +226,16 @@ impl Default for CustomDialect { supports_nulls_first_in_sort: true, use_timestamp_for_date64: false, interval_style: IntervalStyle::SQLStandard, + float64_ast_dtype: sqlparser::ast::DataType::Double, + utf8_cast_dtype: ast::DataType::Varchar(None), + large_utf8_cast_dtype: ast::DataType::Text, + date_field_extract_style: DateFieldExtractStyle::DatePart, + int64_cast_dtype: ast::DataType::BigInt(None), + timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None), + timestamp_tz_cast_dtype: ast::DataType::Timestamp( + None, + TimezoneInfo::WithTimeZone, + ), } } } @@ -158,6 +267,38 @@ impl Dialect for CustomDialect { fn interval_style(&self) -> IntervalStyle { self.interval_style } + + fn float64_ast_dtype(&self) -> sqlparser::ast::DataType { + self.float64_ast_dtype.clone() + } + + fn utf8_cast_dtype(&self) -> ast::DataType { + self.utf8_cast_dtype.clone() + } + + fn large_utf8_cast_dtype(&self) -> ast::DataType { + self.large_utf8_cast_dtype.clone() + } + + fn date_field_extract_style(&self) -> DateFieldExtractStyle { + self.date_field_extract_style + } + + fn int64_cast_dtype(&self) -> ast::DataType { + self.int64_cast_dtype.clone() + } + + fn timestamp_cast_dtype( + &self, + _time_unit: &TimeUnit, + tz: &Option>, + ) -> ast::DataType { + if tz.is_some() { + self.timestamp_tz_cast_dtype.clone() + } else { + self.timestamp_cast_dtype.clone() + } + } } /// `CustomDialectBuilder` to build `CustomDialect` using builder pattern @@ -179,6 +320,13 @@ pub struct CustomDialectBuilder { supports_nulls_first_in_sort: bool, use_timestamp_for_date64: bool, interval_style: IntervalStyle, + float64_ast_dtype: sqlparser::ast::DataType, + utf8_cast_dtype: ast::DataType, + large_utf8_cast_dtype: ast::DataType, + date_field_extract_style: DateFieldExtractStyle, + int64_cast_dtype: ast::DataType, + timestamp_cast_dtype: ast::DataType, + timestamp_tz_cast_dtype: ast::DataType, } impl Default for CustomDialectBuilder 
{ @@ -194,6 +342,16 @@ impl CustomDialectBuilder { supports_nulls_first_in_sort: true, use_timestamp_for_date64: false, interval_style: IntervalStyle::PostgresVerbose, + float64_ast_dtype: sqlparser::ast::DataType::Double, + utf8_cast_dtype: ast::DataType::Varchar(None), + large_utf8_cast_dtype: ast::DataType::Text, + date_field_extract_style: DateFieldExtractStyle::DatePart, + int64_cast_dtype: ast::DataType::BigInt(None), + timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None), + timestamp_tz_cast_dtype: ast::DataType::Timestamp( + None, + TimezoneInfo::WithTimeZone, + ), } } @@ -203,6 +361,13 @@ impl CustomDialectBuilder { supports_nulls_first_in_sort: self.supports_nulls_first_in_sort, use_timestamp_for_date64: self.use_timestamp_for_date64, interval_style: self.interval_style, + float64_ast_dtype: self.float64_ast_dtype, + utf8_cast_dtype: self.utf8_cast_dtype, + large_utf8_cast_dtype: self.large_utf8_cast_dtype, + date_field_extract_style: self.date_field_extract_style, + int64_cast_dtype: self.int64_cast_dtype, + timestamp_cast_dtype: self.timestamp_cast_dtype, + timestamp_tz_cast_dtype: self.timestamp_tz_cast_dtype, } } @@ -235,4 +400,54 @@ impl CustomDialectBuilder { self.interval_style = interval_style; self } + + /// Customize the dialect with a specific SQL type for Float64 casting: DOUBLE, DOUBLE PRECISION, etc. + pub fn with_float64_ast_dtype( + mut self, + float64_ast_dtype: sqlparser::ast::DataType, + ) -> Self { + self.float64_ast_dtype = float64_ast_dtype; + self + } + + /// Customize the dialect with a specific SQL type for Utf8 casting: VARCHAR, CHAR, etc. + pub fn with_utf8_cast_dtype(mut self, utf8_cast_dtype: ast::DataType) -> Self { + self.utf8_cast_dtype = utf8_cast_dtype; + self + } + + /// Customize the dialect with a specific SQL type for LargeUtf8 casting: TEXT, CHAR, etc. + pub fn with_large_utf8_cast_dtype( + mut self, + large_utf8_cast_dtype: ast::DataType, + ) -> Self { + self.large_utf8_cast_dtype = large_utf8_cast_dtype; + self + } + + /// Customize the dialect with a specific date field extract style listed in `DateFieldExtractStyle` + pub fn with_date_field_extract_style( + mut self, + date_field_extract_style: DateFieldExtractStyle, + ) -> Self { + self.date_field_extract_style = date_field_extract_style; + self + } + + /// Customize the dialect with a specific SQL type for Int64 casting: BigInt, SIGNED, etc. + pub fn with_int64_cast_dtype(mut self, int64_cast_dtype: ast::DataType) -> Self { + self.int64_cast_dtype = int64_cast_dtype; + self + } + + /// Customize the dialect with a specific SQL type for Timestamp casting: Timestamp, Datetime, etc. + pub fn with_timestamp_cast_dtype( + mut self, + timestamp_cast_dtype: ast::DataType, + timestamp_tz_cast_dtype: ast::DataType, + ) -> Self { + self.timestamp_cast_dtype = timestamp_cast_dtype; + self.timestamp_tz_cast_dtype = timestamp_tz_cast_dtype; + self + } } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index e6b67b5d9fb2..f4ea44f37d78 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -16,6 +16,13 @@ // under the License. 
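Note (illustrative, not part of the patch): taken together, the new builder hooks above let a caller describe a MySQL-like target without hand-implementing `Dialect`. A hedged sketch that mirrors the tests added further down in this diff, assuming `CustomDialectBuilder`, `DateFieldExtractStyle`, and `Unparser` are reachable at the `datafusion_sql::unparser` paths used here:

```rust
// Illustrative sketch only; not part of this patch.
use arrow_schema::DataType;
use datafusion_expr::{col, expr::Cast, Expr};
use datafusion_sql::unparser::dialect::{CustomDialectBuilder, DateFieldExtractStyle};
use datafusion_sql::unparser::Unparser;
use sqlparser::ast::{self, Ident, ObjectName};

fn main() -> datafusion_common::Result<()> {
    // A MySQL-flavoured dialect assembled from the new hooks: EXTRACT(...)
    // for date parts, SIGNED for Int64 casts, DATETIME for timestamp casts.
    let dialect = CustomDialectBuilder::new()
        .with_date_field_extract_style(DateFieldExtractStyle::Extract)
        .with_int64_cast_dtype(ast::DataType::Custom(
            ObjectName(vec![Ident::new("SIGNED")]),
            vec![],
        ))
        .with_timestamp_cast_dtype(
            ast::DataType::Datetime(None),
            ast::DataType::Datetime(None),
        )
        .build();

    let unparser = Unparser::new(&dialect);
    let expr = Expr::Cast(Cast {
        expr: Box::new(col("a")),
        data_type: DataType::Int64,
    });
    // Int64 casts now render in MySQL's spelling.
    assert_eq!(unparser.expr_to_sql(&expr)?.to_string(), "CAST(a AS SIGNED)");
    Ok(())
}
```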
use core::fmt; + +use datafusion_expr::ScalarUDF; +use sqlparser::ast::Value::SingleQuotedString; +use sqlparser::ast::{ + self, BinaryOperator, Expr as AstExpr, Function, FunctionArg, Ident, Interval, + TimezoneInfo, UnaryOperator, +}; use std::sync::Arc; use std::{fmt::Display, vec}; @@ -28,12 +35,6 @@ use arrow_array::types::{ }; use arrow_array::{Date32Array, Date64Array, PrimitiveArray}; use arrow_schema::DataType; -use sqlparser::ast::Value::SingleQuotedString; -use sqlparser::ast::{ - self, BinaryOperator, Expr as AstExpr, Function, FunctionArg, Ident, Interval, - TimezoneInfo, UnaryOperator, -}; - use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_err, Column, Result, ScalarValue, @@ -43,7 +44,7 @@ use datafusion_expr::{ Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Operator, TryCast, }; -use super::dialect::IntervalStyle; +use super::dialect::{DateFieldExtractStyle, IntervalStyle}; use super::Unparser; /// DataFusion's Exprs can represent either an `Expr` or an `OrderByExpr` @@ -149,6 +150,12 @@ impl Unparser<'_> { Expr::ScalarFunction(ScalarFunction { func, args }) => { let func_name = func.name(); + if let Some(expr) = + self.scalar_function_to_sql_overrides(func_name, func, args) + { + return Ok(expr); + } + let args = args .iter() .map(|e| { @@ -545,6 +552,38 @@ impl Unparser<'_> { } } + fn scalar_function_to_sql_overrides( + &self, + func_name: &str, + _func: &Arc, + args: &[Expr], + ) -> Option { + if func_name.to_lowercase() == "date_part" + && self.dialect.date_field_extract_style() == DateFieldExtractStyle::Extract + && args.len() == 2 + { + let date_expr = self.expr_to_sql(&args[1]).ok()?; + + if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &args[0] { + let field = match field.to_lowercase().as_str() { + "year" => ast::DateTimeField::Year, + "month" => ast::DateTimeField::Month, + "day" => ast::DateTimeField::Day, + "hour" => ast::DateTimeField::Hour, + "minute" => ast::DateTimeField::Minute, + "second" => ast::DateTimeField::Second, + _ => return None, + }; + + return Some(ast::Expr::Extract { + field, + expr: Box::new(date_expr), + }); + } + } + None + } + fn ast_type_for_date64_in_cast(&self) -> ast::DataType { if self.dialect.use_timestamp_for_date64() { ast::DataType::Timestamp(None, ast::TimezoneInfo::None) @@ -1105,6 +1144,131 @@ impl Unparser<'_> { } } + /// MySQL requires INTERVAL sql to be in the format: INTERVAL 1 YEAR + INTERVAL 1 MONTH + INTERVAL 1 DAY etc + /// `` + /// Interval sequence can't be wrapped in brackets - (INTERVAL 1 YEAR + INTERVAL 1 MONTH ...) 
so we need to generate + /// a single INTERVAL expression so it works correct for interval substraction cases + /// MySQL supports the DAY_MICROSECOND unit type (format is DAYS HOURS:MINUTES:SECONDS.MICROSECONDS), but it is not supported by sqlparser + /// so we calculate the best single interval to represent the provided duration + fn interval_to_mysql_expr( + &self, + months: i32, + days: i32, + microseconds: i64, + ) -> Result { + // MONTH only + if months != 0 && days == 0 && microseconds == 0 { + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + months.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Month), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + return Ok(ast::Expr::Interval(interval)); + } else if months != 0 { + return not_impl_err!("Unsupported Interval scalar with both Month and DayTime for IntervalStyle::MySQL"); + } + + // DAY only + if microseconds == 0 { + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + days.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Day), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + return Ok(ast::Expr::Interval(interval)); + } + + // calculate the best single interval to represent the provided days and microseconds + + let microseconds = microseconds + (days as i64 * 24 * 60 * 60 * 1_000_000); + + if microseconds % 1_000_000 != 0 { + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + microseconds.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Microsecond), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + return Ok(ast::Expr::Interval(interval)); + } + + let secs = microseconds / 1_000_000; + + if secs % 60 != 0 { + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + secs.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Second), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + return Ok(ast::Expr::Interval(interval)); + } + + let mins = secs / 60; + + if mins % 60 != 0 { + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + mins.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Minute), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + return Ok(ast::Expr::Interval(interval)); + } + + let hours = mins / 60; + + if hours % 24 != 0 { + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + hours.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Hour), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + return Ok(ast::Expr::Interval(interval)); + } + + let days = hours / 24; + + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + days.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Day), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + Ok(ast::Expr::Interval(interval)) + } + fn interval_scalar_to_sql(&self, v: &ScalarValue) -> Result { match self.dialect.interval_style() { IntervalStyle::PostgresVerbose => { @@ -1127,10 +1291,7 @@ impl Unparser<'_> { } // If the interval standard is SQLStandard, implement a simple unparse logic IntervalStyle::SQLStandard => match v { - 
ScalarValue::IntervalYearMonth(v) => { - let Some(v) = v else { - return Ok(ast::Expr::Value(ast::Value::Null)); - }; + ScalarValue::IntervalYearMonth(Some(v)) => { let interval = Interval { value: Box::new(ast::Expr::Value( ast::Value::SingleQuotedString(v.to_string()), @@ -1142,10 +1303,7 @@ impl Unparser<'_> { }; Ok(ast::Expr::Interval(interval)) } - ScalarValue::IntervalDayTime(v) => { - let Some(v) = v else { - return Ok(ast::Expr::Value(ast::Value::Null)); - }; + ScalarValue::IntervalDayTime(Some(v)) => { let days = v.days; let secs = v.milliseconds / 1_000; let mins = secs / 60; @@ -1168,11 +1326,7 @@ impl Unparser<'_> { }; Ok(ast::Expr::Interval(interval)) } - ScalarValue::IntervalMonthDayNano(v) => { - let Some(v) = v else { - return Ok(ast::Expr::Value(ast::Value::Null)); - }; - + ScalarValue::IntervalMonthDayNano(Some(v)) => { if v.months >= 0 && v.days == 0 && v.nanoseconds == 0 { let interval = Interval { value: Box::new(ast::Expr::Value( @@ -1184,10 +1338,7 @@ impl Unparser<'_> { fractional_seconds_precision: None, }; Ok(ast::Expr::Interval(interval)) - } else if v.months == 0 - && v.days >= 0 - && v.nanoseconds % 1_000_000 == 0 - { + } else if v.months == 0 && v.nanoseconds % 1_000_000 == 0 { let days = v.days; let secs = v.nanoseconds / 1_000_000_000; let mins = secs / 60; @@ -1214,11 +1365,29 @@ impl Unparser<'_> { not_impl_err!("Unsupported IntervalMonthDayNano scalar with both Month and DayTime for IntervalStyle::SQLStandard") } } - _ => Ok(ast::Expr::Value(ast::Value::Null)), + _ => not_impl_err!( + "Unsupported ScalarValue for Interval conversion: {v:?}" + ), + }, + IntervalStyle::MySQL => match v { + ScalarValue::IntervalYearMonth(Some(v)) => { + self.interval_to_mysql_expr(*v, 0, 0) + } + ScalarValue::IntervalDayTime(Some(v)) => { + self.interval_to_mysql_expr(0, v.days, v.milliseconds as i64 * 1_000) + } + ScalarValue::IntervalMonthDayNano(Some(v)) => { + if v.nanoseconds % 1_000 != 0 { + return not_impl_err!( + "Unsupported IntervalMonthDayNano scalar with nanoseconds precision for IntervalStyle::MySQL" + ); + } + self.interval_to_mysql_expr(v.months, v.days, v.nanoseconds / 1_000) + } + _ => not_impl_err!( + "Unsupported ScalarValue for Interval conversion: {v:?}" + ), }, - IntervalStyle::MySQL => { - not_impl_err!("Unsupported interval scalar for IntervalStyle::MySQL") - } } } @@ -1231,7 +1400,7 @@ impl Unparser<'_> { DataType::Int8 => Ok(ast::DataType::TinyInt(None)), DataType::Int16 => Ok(ast::DataType::SmallInt(None)), DataType::Int32 => Ok(ast::DataType::Integer(None)), - DataType::Int64 => Ok(ast::DataType::BigInt(None)), + DataType::Int64 => Ok(self.dialect.int64_cast_dtype()), DataType::UInt8 => Ok(ast::DataType::UnsignedTinyInt(None)), DataType::UInt16 => Ok(ast::DataType::UnsignedSmallInt(None)), DataType::UInt32 => Ok(ast::DataType::UnsignedInteger(None)), @@ -1240,14 +1409,9 @@ impl Unparser<'_> { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") } DataType::Float32 => Ok(ast::DataType::Float(None)), - DataType::Float64 => Ok(ast::DataType::Double), - DataType::Timestamp(_, tz) => { - let tz_info = match tz { - Some(_) => TimezoneInfo::WithTimeZone, - None => TimezoneInfo::None, - }; - - Ok(ast::DataType::Timestamp(None, tz_info)) + DataType::Float64 => Ok(self.dialect.float64_ast_dtype()), + DataType::Timestamp(time_unit, tz) => { + Ok(self.dialect.timestamp_cast_dtype(time_unit, tz)) } DataType::Date32 => Ok(ast::DataType::Date), DataType::Date64 => Ok(self.ast_type_for_date64_in_cast()), @@ -1275,8 +1439,8 @@ impl Unparser<'_> { 
DataType::BinaryView => { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") } - DataType::Utf8 => Ok(ast::DataType::Varchar(None)), - DataType::LargeUtf8 => Ok(ast::DataType::Text), + DataType::Utf8 => Ok(self.dialect.utf8_cast_dtype()), + DataType::LargeUtf8 => Ok(self.dialect.large_utf8_cast_dtype()), DataType::Utf8View => { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") } @@ -1335,6 +1499,7 @@ mod tests { use arrow::datatypes::TimeUnit; use arrow::datatypes::{Field, Schema}; use arrow_schema::DataType::Int8; + use ast::ObjectName; use datafusion_common::TableReference; use datafusion_expr::{ case, col, cube, exists, grouping_set, interval_datetime_lit, @@ -1822,6 +1987,34 @@ mod tests { Ok(()) } + #[test] + fn custom_dialect_float64_ast_dtype() -> Result<()> { + for (float64_ast_dtype, identifier) in [ + (sqlparser::ast::DataType::Double, "DOUBLE"), + ( + sqlparser::ast::DataType::DoublePrecision, + "DOUBLE PRECISION", + ), + ] { + let dialect = CustomDialectBuilder::new() + .with_float64_ast_dtype(float64_ast_dtype) + .build(); + let unparser = Unparser::new(&dialect); + + let expr = Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Float64, + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + + let expected = format!(r#"CAST(a AS {identifier})"#); + assert_eq!(actual, expected); + } + Ok(()) + } + #[test] fn customer_dialect_support_nulls_first_in_ort() -> Result<()> { let tests: Vec<(Expr, &str, bool)> = vec![ @@ -1857,6 +2050,11 @@ mod tests { IntervalStyle::SQLStandard, "INTERVAL '1 12:0:0.000' DAY TO SECOND", ), + ( + interval_month_day_nano_lit("-1.5 DAY"), + IntervalStyle::SQLStandard, + "INTERVAL '-1 -12:0:0.000' DAY TO SECOND", + ), ( interval_month_day_nano_lit("1.51234 DAY"), IntervalStyle::SQLStandard, @@ -1921,6 +2119,46 @@ mod tests { IntervalStyle::PostgresVerbose, r#"INTERVAL '1 YEARS 7 MONS 0 DAYS 0 HOURS 0 MINS 0.00 SECS'"#, ), + ( + interval_year_month_lit("1 YEAR 1 MONTH"), + IntervalStyle::MySQL, + r#"INTERVAL 13 MONTH"#, + ), + ( + interval_month_day_nano_lit("1 YEAR -1 MONTH"), + IntervalStyle::MySQL, + r#"INTERVAL 11 MONTH"#, + ), + ( + interval_month_day_nano_lit("15 DAY"), + IntervalStyle::MySQL, + r#"INTERVAL 15 DAY"#, + ), + ( + interval_month_day_nano_lit("-40 HOURS"), + IntervalStyle::MySQL, + r#"INTERVAL -40 HOUR"#, + ), + ( + interval_datetime_lit("-1.5 DAY 1 HOUR"), + IntervalStyle::MySQL, + "INTERVAL -35 HOUR", + ), + ( + interval_datetime_lit("1000000 DAY 1.5 HOUR 10 MINUTE 20 SECOND"), + IntervalStyle::MySQL, + r#"INTERVAL 86400006020 SECOND"#, + ), + ( + interval_year_month_lit("0 DAY 0 HOUR"), + IntervalStyle::MySQL, + r#"INTERVAL 0 DAY"#, + ), + ( + interval_month_day_nano_lit("-1296000000 SECOND"), + IntervalStyle::MySQL, + r#"INTERVAL -15000 DAY"#, + ), ]; for (value, style, expected) in tests { @@ -1936,4 +2174,149 @@ mod tests { assert_eq!(actual, expected); } } + + #[test] + fn custom_dialect_use_char_for_utf8_cast() -> Result<()> { + let default_dialect = CustomDialectBuilder::default().build(); + let mysql_custom_dialect = CustomDialectBuilder::new() + .with_utf8_cast_dtype(ast::DataType::Char(None)) + .with_large_utf8_cast_dtype(ast::DataType::Char(None)) + .build(); + + for (dialect, data_type, identifier) in [ + (&default_dialect, DataType::Utf8, "VARCHAR"), + (&default_dialect, DataType::LargeUtf8, "TEXT"), + (&mysql_custom_dialect, DataType::Utf8, "CHAR"), + (&mysql_custom_dialect, DataType::LargeUtf8, "CHAR"), + ] { + let unparser = 
Unparser::new(dialect); + + let expr = Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type, + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = format!(r#"CAST(a AS {identifier})"#); + + assert_eq!(actual, expected); + } + Ok(()) + } + + #[test] + fn custom_dialect_with_date_field_extract_style() -> Result<()> { + for (extract_style, unit, expected) in [ + ( + DateFieldExtractStyle::DatePart, + "YEAR", + "date_part('YEAR', x)", + ), + ( + DateFieldExtractStyle::Extract, + "YEAR", + "EXTRACT(YEAR FROM x)", + ), + ( + DateFieldExtractStyle::DatePart, + "MONTH", + "date_part('MONTH', x)", + ), + ( + DateFieldExtractStyle::Extract, + "MONTH", + "EXTRACT(MONTH FROM x)", + ), + ( + DateFieldExtractStyle::DatePart, + "DAY", + "date_part('DAY', x)", + ), + (DateFieldExtractStyle::Extract, "DAY", "EXTRACT(DAY FROM x)"), + ] { + let dialect = CustomDialectBuilder::new() + .with_date_field_extract_style(extract_style) + .build(); + + let unparser = Unparser::new(&dialect); + let expr = ScalarUDF::new_from_impl( + datafusion_functions::datetime::date_part::DatePartFunc::new(), + ) + .call(vec![Expr::Literal(ScalarValue::new_utf8(unit)), col("x")]); + + let ast = unparser.expr_to_sql(&expr)?; + let actual = format!("{}", ast); + + assert_eq!(actual, expected); + } + Ok(()) + } + + #[test] + fn custom_dialect_with_int64_cast_dtype() -> Result<()> { + let default_dialect = CustomDialectBuilder::new().build(); + let mysql_dialect = CustomDialectBuilder::new() + .with_int64_cast_dtype(ast::DataType::Custom( + ObjectName(vec![Ident::new("SIGNED")]), + vec![], + )) + .build(); + + for (dialect, identifier) in + [(default_dialect, "BIGINT"), (mysql_dialect, "SIGNED")] + { + let unparser = Unparser::new(&dialect); + let expr = Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Int64, + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = format!(r#"CAST(a AS {identifier})"#); + + assert_eq!(actual, expected); + } + Ok(()) + } + + #[test] + fn custom_dialect_with_teimstamp_cast_dtype() -> Result<()> { + let default_dialect = CustomDialectBuilder::new().build(); + let mysql_dialect = CustomDialectBuilder::new() + .with_timestamp_cast_dtype( + ast::DataType::Datetime(None), + ast::DataType::Datetime(None), + ) + .build(); + + let timestamp = DataType::Timestamp(TimeUnit::Nanosecond, None); + let timestamp_with_tz = + DataType::Timestamp(TimeUnit::Nanosecond, Some("+08:00".into())); + + for (dialect, data_type, identifier) in [ + (&default_dialect, ×tamp, "TIMESTAMP"), + ( + &default_dialect, + ×tamp_with_tz, + "TIMESTAMP WITH TIME ZONE", + ), + (&mysql_dialect, ×tamp, "DATETIME"), + (&mysql_dialect, ×tamp_with_tz, "DATETIME"), + ] { + let unparser = Unparser::new(dialect); + let expr = Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: data_type.clone(), + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = format!(r#"CAST(a AS {identifier})"#); + + assert_eq!(actual, expected); + } + Ok(()) + } } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 7a653f80be08..59660f4f0404 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -19,7 +19,7 @@ use datafusion_common::{internal_err, not_impl_err, plan_err, DataFusionError, R use datafusion_expr::{ expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, Projection, }; -use sqlparser::ast::{self, 
SetExpr}; +use sqlparser::ast::{self, Ident, SetExpr}; use crate::unparser::utils::unproject_agg_exprs; @@ -29,6 +29,7 @@ use super::{ SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder, }, rewrite::normalize_union_schema, + rewrite::rewrite_plan_for_sort_on_non_projected_fields, utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant}, Unparser, }; @@ -199,33 +200,21 @@ impl Unparser<'_> { Ok(()) } - fn projection_to_sql( - &self, - plan: &LogicalPlan, - p: &Projection, - query: &mut Option, - select: &mut SelectBuilder, - relation: &mut RelationBuilder, - ) -> Result<()> { - // A second projection implies a derived tablefactor - if !select.already_projected() { - self.reconstruct_select_statement(plan, p, select)?; - self.select_to_sql_recursively(p.input.as_ref(), query, select, relation) - } else { - let mut derived_builder = DerivedRelationBuilder::default(); - derived_builder.lateral(false).alias(None).subquery({ - let inner_statment = self.plan_to_sql(plan)?; - if let ast::Statement::Query(inner_query) = inner_statment { - inner_query - } else { - return internal_err!( - "Subquery must be a Query, but found {inner_statment:?}" - ); - } - }); - relation.derived(derived_builder); - Ok(()) - } + fn derive(&self, plan: &LogicalPlan, relation: &mut RelationBuilder) -> Result<()> { + let mut derived_builder = DerivedRelationBuilder::default(); + derived_builder.lateral(false).alias(None).subquery({ + let inner_statement = self.plan_to_sql(plan)?; + if let ast::Statement::Query(inner_query) = inner_statement { + inner_query + } else { + return internal_err!( + "Subquery must be a Query, but found {inner_statement:?}" + ); + } + }); + relation.derived(derived_builder); + + Ok(()) } fn select_to_sql_recursively( @@ -256,7 +245,17 @@ impl Unparser<'_> { Ok(()) } LogicalPlan::Projection(p) => { - self.projection_to_sql(plan, p, query, select, relation) + if let Some(new_plan) = rewrite_plan_for_sort_on_non_projected_fields(p) { + return self + .select_to_sql_recursively(&new_plan, query, select, relation); + } + + // Projection can be top-level plan for derived table + if select.already_projected() { + return self.derive(plan, relation); + } + self.reconstruct_select_statement(plan, p, select)?; + self.select_to_sql_recursively(p.input.as_ref(), query, select, relation) } LogicalPlan::Filter(filter) => { if let Some(AggVariant::Aggregate(agg)) = @@ -278,6 +277,10 @@ impl Unparser<'_> { ) } LogicalPlan::Limit(limit) => { + // Limit can be top-level plan for derived table + if select.already_projected() { + return self.derive(plan, relation); + } if let Some(fetch) = limit.fetch { let Some(query) = query.as_mut() else { return internal_err!( @@ -298,6 +301,10 @@ impl Unparser<'_> { ) } LogicalPlan::Sort(sort) => { + // Sort can be top-level plan for derived table + if select.already_projected() { + return self.derive(plan, relation); + } if let Some(query_ref) = query { query_ref.order_by(self.sort_to_sql(sort.expr.clone())?); } else { @@ -323,6 +330,10 @@ impl Unparser<'_> { ) } LogicalPlan::Distinct(distinct) => { + // Distinct can be top-level plan for derived table + if select.already_projected() { + return self.derive(plan, relation); + } let (select_distinct, input) = match distinct { Distinct::All(input) => (ast::Distinct::Distinct, input.as_ref()), Distinct::On(on) => { @@ -457,15 +468,11 @@ impl Unparser<'_> { } LogicalPlan::SubqueryAlias(plan_alias) => { // Handle bottom-up to allocate relation - self.select_to_sql_recursively( - plan_alias.input.as_ref(), - 
query, - select, - relation, - )?; + let (plan, columns) = subquery_alias_inner_query_and_columns(plan_alias); + self.select_to_sql_recursively(plan, query, select, relation)?; relation.alias(Some( - self.new_table_alias(plan_alias.alias.table().to_string()), + self.new_table_alias(plan_alias.alias.table().to_string(), columns), )); Ok(()) @@ -599,10 +606,10 @@ impl Unparser<'_> { self.binary_op_to_sql(lhs, rhs, ast::BinaryOperator::And) } - fn new_table_alias(&self, alias: String) -> ast::TableAlias { + fn new_table_alias(&self, alias: String, columns: Vec<Ident>) -> ast::TableAlias { ast::TableAlias { name: self.new_ident_quoted_if_needs(alias), - columns: Vec::new(), + columns, } } @@ -611,6 +618,67 @@ impl Unparser<'_> { } } +// This logic works out the inner query and the column aliases for a SubqueryAlias plan, for both +// forms of subquery +// - `(SELECT column_a as a from table) AS A` +// - `(SELECT column_a from table) AS A (a)` +// +// A roundtrip example for a table alias with columns: +// +// query: SELECT id FROM (SELECT j1_id from j1) AS c (id) +// +// LogicalPlan: +// Projection: c.id +// SubqueryAlias: c +// Projection: j1.j1_id AS id +// Projection: j1.j1_id +// TableScan: j1 +// +// Before introducing this logic, the unparsed query would be `SELECT c.id FROM (SELECT j1.j1_id AS +// id FROM (SELECT j1.j1_id FROM j1)) AS c`. +// That query is invalid because `j1.j1_id` is not a valid identifier in the derived table +// `(SELECT j1.j1_id FROM j1)`. +// +// With this logic, the unparsed query will be: +// `SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)` +// +// Caveat: this won't handle a case like `select * from (select 1, 2) AS a (b, c)`, +// because the parser currently produces an incorrect plan in which the `Int(1)` expressions have +// mismatched types (Literal vs. Column) in the projections. Once the parser side is fixed, this logic should +// cover that case as well. fn subquery_alias_inner_query_and_columns( + subquery_alias: &datafusion_expr::SubqueryAlias, +) -> (&LogicalPlan, Vec<Ident>) { + let plan: &LogicalPlan = subquery_alias.input.as_ref(); + + let LogicalPlan::Projection(outer_projections) = plan else { + return (plan, vec![]); + }; + + // check if it's projection inside projection + let LogicalPlan::Projection(inner_projection) = outer_projections.input.as_ref() + else { + return (plan, vec![]); + }; + + let mut columns: Vec<Ident> = vec![]; + // check if the inner projection and outer projection have a matching pattern like + // Projection: j1.j1_id AS id + // Projection: j1.j1_id + for (i, inner_expr) in inner_projection.expr.iter().enumerate() { + let Expr::Alias(ref outer_alias) = &outer_projections.expr[i] else { + return (plan, vec![]); + }; + + if outer_alias.expr.as_ref() != inner_expr { + return (plan, vec![]); + }; + + columns.push(outer_alias.name.as_str().into()); + } + + (outer_projections.input.as_ref(), columns) +} + impl From<BuilderError> for DataFusionError { fn from(e: BuilderError) -> Self { DataFusionError::External(Box::new(e)) diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index a73fce30ced3..fba95ad48f32 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -15,13 +15,16 @@ // specific language governing permissions and limitations // under the License.
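The alias-column derivation in `subquery_alias_inner_query_and_columns` above hinges on a positional match between the outer aliasing projection and the inner projection. Below is a minimal, dependency-free sketch of that pairing rule, using plain strings as stand-ins for DataFusion's `Expr`; the function name and types are illustrative only and are not part of the crate.

```rust
// Illustrative sketch only (not DataFusion code): the outer projection must be
// a list of `inner_expr AS alias` pairs that line up position by position with
// the inner projection, otherwise no column list can be derived.
fn alias_columns(outer: &[(&str, &str)], inner: &[&str]) -> Option<Vec<String>> {
    // `outer` holds (aliased inner expression, alias name) pairs, e.g. ("j1.j1_id", "id").
    if outer.len() != inner.len() {
        return None;
    }
    outer
        .iter()
        .zip(inner)
        .map(|((expr, alias), inner_expr)| (expr == inner_expr).then(|| alias.to_string()))
        .collect()
}

fn main() {
    // Mirrors the roundtrip example in the comment above:
    //   SELECT id FROM (SELECT j1_id FROM j1) AS c (id)
    let inner = ["j1.j1_id"];
    assert_eq!(
        alias_columns(&[("j1.j1_id", "id")], &inner),
        Some(vec!["id".to_string()])
    );
    // Any mismatch falls back to "no explicit column list".
    assert_eq!(alias_columns(&[("j1.j1_id + 1", "id")], &inner), None);
}
```

In the real implementation the mismatch cases return `(plan, vec![])`, which unparses as a plain `AS c` alias without a column list.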
-use std::sync::Arc; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeIterator}, Result, }; -use datafusion_expr::{Expr, LogicalPlan, Sort}; +use datafusion_expr::{Expr, LogicalPlan, Projection, Sort}; /// Normalize the schema of a union plan to remove qualifiers from the schema fields and sort expressions. /// @@ -99,3 +102,76 @@ fn rewrite_sort_expr_for_union(exprs: Vec<Expr>) -> Result<Vec<Expr>> { Ok(sort_exprs) } + +// Rewrite the logical plan for queries whose ORDER BY columns are not in the projection. +// Plan before rewrite: +// +// Projection: j1.j1_string, j2.j2_string +// Sort: j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST +// Projection: j1.j1_string, j2.j2_string, j1.j1_id, j2.j2_id +// Inner Join: Filter: j1.j1_id = j2.j2_id +// TableScan: j1 +// TableScan: j2 +// +// Plan after rewrite: +// +// Sort: j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST +// Projection: j1.j1_string, j2.j2_string +// Inner Join: Filter: j1.j1_id = j2.j2_id +// TableScan: j1 +// TableScan: j2 +// +// This prevents the original plan from generating a query with a derived table that is missing an alias. +pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( + p: &Projection, +) -> Option<LogicalPlan> { + let LogicalPlan::Sort(sort) = p.input.as_ref() else { + return None; + }; + + let LogicalPlan::Projection(inner_p) = sort.input.as_ref() else { + return None; + }; + + let mut map = HashMap::new(); + let inner_exprs = inner_p + .expr + .iter() + .map(|f| { + if let Expr::Alias(alias) = f { + let a = Expr::Column(alias.name.clone().into()); + map.insert(a.clone(), f.clone()); + a + } else { + f.clone() + } + }) + .collect::<Vec<_>>(); + + let mut collects = p.expr.clone(); + for expr in &sort.expr { + if let Expr::Sort(s) = expr { + collects.push(s.expr.as_ref().clone()); + } + } + + if collects.iter().collect::<HashSet<_>>() + == inner_exprs.iter().collect::<HashSet<_>>() + { + let mut sort = sort.clone(); + let mut inner_p = inner_p.clone(); + + let new_exprs = p + .expr + .iter() + .map(|e| map.get(e).unwrap_or(e).clone()) + .collect::<Vec<_>>(); + + inner_p.expr.clone_from(&new_exprs); + sort.input = Arc::new(LogicalPlan::Projection(inner_p)); + + Some(LogicalPlan::Sort(sort)) + } else { + None + } +} diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 331da9773f16..71f64f1cf459 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -31,7 +31,7 @@ pub(crate) enum AggVariant<'a> { /// Recursively searches children of [LogicalPlan] to find an Aggregate or window node if one exists /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). /// If an Aggregate or window node is not found prior to this or at all before reaching the end -/// of the tree, None is returned. It is assumed that a Window and Aggegate node cannot both +/// of the tree, None is returned. It is assumed that a Window and Aggregate node cannot both /// be found in a single select query.
pub(crate) fn find_agg_node_within_select<'a>( plan: &'a LogicalPlan, @@ -82,7 +82,7 @@ pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result expr.clone() .transform(|sub_expr| { if let Expr::Column(c) = sub_expr { - // find the column in the agg schmea + // find the column in the agg schema if let Ok(n) = agg.schema.index_of_column(&c) { let unprojected_expr = agg .group_expr diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 2eacbd174fc2..a70e3e9be930 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -325,7 +325,7 @@ pub(crate) fn transform_bottom_unnest( let (data_type, _) = arg.data_type_and_nullable(input.schema())?; if let DataType::Struct(_) = data_type { - return internal_err!("unnest on struct can ony be applied at the root level of select expression"); + return internal_err!("unnest on struct can only be applied at the root level of select expression"); } let mut transformed_exprs = transform(&expr, arg)?; diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 91295b2e8aae..aada560fd884 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; use std::vec; use arrow_schema::*; @@ -28,6 +29,7 @@ use datafusion_sql::unparser::dialect::{ }; use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser}; +use datafusion_functions::core::planner::CoreFunctionPlanner; use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect}; use sqlparser::parser::Parser; @@ -155,7 +157,8 @@ fn roundtrip_statement() -> Result<()> { let context = MockContextProvider::default() .with_udaf(sum_udaf()) - .with_udaf(count_udaf()); + .with_udaf(count_udaf()) + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); @@ -184,7 +187,8 @@ fn roundtrip_crossjoin() -> Result<()> { .try_with_sql(query)? 
.parse_statement()?; - let context = MockContextProvider::default(); + let context = MockContextProvider::default() + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); @@ -240,6 +244,133 @@ fn roundtrip_statement_with_dialect() -> Result<()> { parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), }, + // Test query with derived tables that put distinct,sort,limit on the wrong level + TestStatementWithDialect { + sql: "SELECT j1_string from j1 order by j1_id", + expected: r#"SELECT j1.j1_string FROM j1 ORDER BY j1.j1_id ASC NULLS LAST"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT j1_string AS a from j1 order by j1_id", + expected: r#"SELECT j1.j1_string AS a FROM j1 ORDER BY j1.j1_id ASC NULLS LAST"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT j1_string from j1 join j2 on j1.j1_id = j2.j2_id order by j1_id", + expected: r#"SELECT j1.j1_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id ASC NULLS LAST"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: " + SELECT + j1_string, + j2_string + FROM + ( + SELECT + distinct j1_id, + j1_string, + j2_string + from + j1 + INNER join j2 ON j1.j1_id = j2.j2_id + order by + j1.j1_id desc + limit + 10 + ) abc + ORDER BY + abc.j2_string", + expected: r#"SELECT abc.j1_string, abc.j2_string FROM (SELECT DISTINCT j1.j1_id, j1.j1_string, j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + // more tests around subquery/derived table roundtrip + TestStatementWithDialect { + sql: "SELECT string_count FROM ( + SELECT + j1_id, + MIN(j2_string) + FROM + j1 LEFT OUTER JOIN j2 ON + j1_id = j2_id + GROUP BY + j1_id + ) AS agg (id, string_count) + ", + expected: r#"SELECT agg.string_count FROM (SELECT j1.j1_id, MIN(j2.j2_string) FROM j1 LEFT JOIN j2 ON (j1.j1_id = j2.j2_id) GROUP BY j1.j1_id) AS agg (id, string_count)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: " + SELECT + j1_string, + j2_string + FROM + ( + SELECT + j1_id, + j1_string, + j2_string + from + j1 + INNER join j2 ON j1.j1_id = j2.j2_id + group by + j1_id, + j1_string, + j2_string + order by + j1.j1_id desc + limit + 10 + ) abc + ORDER BY + abc.j2_string", + expected: r#"SELECT abc.j1_string, abc.j2_string FROM (SELECT j1.j1_id, j1.j1_string, j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) GROUP BY j1.j1_id, j1.j1_string, j2.j2_string ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + // Test query that order by columns are not in select columns + TestStatementWithDialect { + sql: " + SELECT + j1_string + FROM + ( + SELECT + j1_string, + j2_string + from + j1 + INNER join j2 ON j1.j1_id = j2.j2_id + order by + j1.j1_id desc, + j2.j2_id desc + limit + 10 + ) 
abc + ORDER BY + j2_string", + expected: r#"SELECT abc.j1_string FROM (SELECT j1.j1_string, j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT id FROM (SELECT j1_id from j1) AS c (id)", + expected: r#"SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT id FROM (SELECT j1_id as id from j1) AS c", + expected: r#"SELECT c.id FROM (SELECT j1.j1_id AS id FROM j1) AS c"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, ]; for query in tests { @@ -247,7 +378,8 @@ fn roundtrip_statement_with_dialect() -> Result<()> { .try_with_sql(query.sql)? .parse_statement()?; - let context = MockContextProvider::default(); + let context = MockContextProvider::default() + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel .sql_statement_to_plan(statement) diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index b8d8bd12d28b..374aa9db6714 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -25,6 +25,7 @@ use arrow_schema::*; use datafusion_common::config::ConfigOptions; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{plan_err, GetExt, Result, TableReference}; +use datafusion_expr::planner::ExprPlanner; use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; use datafusion_sql::planner::ContextProvider; @@ -53,10 +54,11 @@ pub(crate) struct MockContextProvider { options: ConfigOptions, udfs: HashMap>, udafs: HashMap>, + expr_planners: Vec>, } impl MockContextProvider { - // Surpressing dead code warning, as this is used in integration test crates + // Suppressing dead code warning, as this is used in integration test crates #[allow(dead_code)] pub(crate) fn options_mut(&mut self) -> &mut ConfigOptions { &mut self.options @@ -73,6 +75,11 @@ impl MockContextProvider { self.udafs.insert(udaf.name().to_lowercase(), udaf); self } + + pub(crate) fn with_expr_planner(mut self, planner: Arc) -> Self { + self.expr_planners.push(planner); + self + } } impl ContextProvider for MockContextProvider { @@ -240,6 +247,10 @@ impl ContextProvider for MockContextProvider { fn udwf_names(&self) -> Vec { Vec::new() } + + fn get_expr_planners(&self) -> &[Arc] { + &self.expr_planners + } } struct EmptyTable { diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index e34e7e20a0f3..3291560383df 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -18,6 +18,7 @@ use std::any::Any; #[cfg(test)] use std::collections::HashMap; +use std::sync::Arc; use std::vec; use arrow_schema::TimeUnit::Nanosecond; @@ -37,6 +38,7 @@ use datafusion_sql::{ planner::{ParserOptions, SqlToRel}, }; +use datafusion_functions::core::planner::CoreFunctionPlanner; use datafusion_functions_aggregate::{ approx_median::approx_median_udaf, count::count_udaf, }; @@ -1144,7 +1146,7 @@ fn select_aggregate_with_group_by_with_having_that_reuses_aggregate_multiple_tim } #[test] -fn 
select_aggregate_with_group_by_with_having_using_aggreagate_not_in_select() { +fn select_aggregate_with_group_by_with_having_using_aggregate_not_in_select() { let sql = "SELECT first_name, MAX(age) FROM person GROUP BY first_name @@ -1185,7 +1187,7 @@ fn select_aggregate_compound_aliased_with_group_by_with_having_referencing_compo } #[test] -fn select_aggregate_with_group_by_with_having_using_derived_column_aggreagate_not_in_select( +fn select_aggregate_with_group_by_with_having_using_derived_column_aggregate_not_in_select( ) { let sql = "SELECT first_name, MAX(age) FROM person @@ -2694,7 +2696,8 @@ fn logical_plan_with_dialect_and_options( .with_udaf(approx_median_udaf()) .with_udaf(count_udaf()) .with_udaf(avg_udaf()) - .with_udaf(grouping_udaf()); + .with_udaf(grouping_udaf()) + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); let planner = SqlToRel::new_with_options(&context, options); let result = DFParser::parse_sql_with_dialect(sql, dialect); diff --git a/datafusion/sqllogictest/README.md b/datafusion/sqllogictest/README.md index c7f04c0d762c..5becc75c985a 100644 --- a/datafusion/sqllogictest/README.md +++ b/datafusion/sqllogictest/README.md @@ -133,7 +133,7 @@ In order to run the sqllogictests running against a previously running Postgres PG_COMPAT=true PG_URI="postgresql://postgres@127.0.0.1/postgres" cargo test --features=postgres --test sqllogictests ``` -The environemnt variables: +The environment variables: 1. `PG_COMPAT` instructs sqllogictest to run against Postgres (not DataFusion) 2. `PG_URI` contains a `libpq` style connection string, whose format is described in diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index a0140b1c5292..fa228d499d1f 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -129,12 +129,12 @@ query TT explain select array_agg(c1 order by c2 desc, c3) from agg_order; ---- logical_plan -01)Aggregate: groupBy=[[]], aggr=[[ARRAY_AGG(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]]] +01)Aggregate: groupBy=[[]], aggr=[[array_agg(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]]] 02)--TableScan: agg_order projection=[c1, c2, c3] physical_plan -01)AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]] +01)AggregateExec: mode=Final, gby=[], aggr=[array_agg(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]] 02)--CoalescePartitionsExec -03)----AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]] +03)----AggregateExec: mode=Partial, gby=[], aggr=[array_agg(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]] 04)------SortExec: expr=[c2@1 DESC,c3@2 ASC NULLS LAST], preserve_partitioning=[true] 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/aggregate_agg_multi_order.csv]]}, projection=[c1, c2, c3], has_header=true @@ -183,7 +183,7 @@ CREATE TABLE array_agg_distinct_list_table AS VALUES ; # Apply array_sort to have deterministic result, higher dimension nested array also works but not for array sort, -# so they are covered in `datafusion/physical-expr/src/aggregate/array_agg_distinct.rs` +# so they are 
covered in `datafusion/functions-aggregate/src/array_agg.rs` query ?? select array_sort(c1), array_sort(c2) from ( select array_agg(distinct column1) as c1, array_agg(distinct column2) as c2 from array_agg_distinct_list_table @@ -231,8 +231,8 @@ explain with A as ( ) select array_length(array_agg(distinct a.foo)), sum(distinct 1) from A a group by a.id; ---- logical_plan -01)Projection: array_length(ARRAY_AGG(DISTINCT a.foo)), sum(DISTINCT Int64(1)) -02)--Aggregate: groupBy=[[a.id]], aggr=[[ARRAY_AGG(DISTINCT a.foo), sum(DISTINCT Int64(1))]] +01)Projection: array_length(array_agg(DISTINCT a.foo)), sum(DISTINCT Int64(1)) +02)--Aggregate: groupBy=[[a.id]], aggr=[[array_agg(DISTINCT a.foo), sum(DISTINCT Int64(1))]] 03)----SubqueryAlias: a 04)------SubqueryAlias: a 05)--------Union @@ -247,11 +247,11 @@ logical_plan 14)----------Projection: Int64(1) AS id, Int64(2) AS foo 15)------------EmptyRelation physical_plan -01)ProjectionExec: expr=[array_length(ARRAY_AGG(DISTINCT a.foo)@1) as array_length(ARRAY_AGG(DISTINCT a.foo)), sum(DISTINCT Int64(1))@2 as sum(DISTINCT Int64(1))] -02)--AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[ARRAY_AGG(DISTINCT a.foo), sum(DISTINCT Int64(1))] +01)ProjectionExec: expr=[array_length(array_agg(DISTINCT a.foo)@1) as array_length(array_agg(DISTINCT a.foo)), sum(DISTINCT Int64(1))@2 as sum(DISTINCT Int64(1))] +02)--AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[array_agg(DISTINCT a.foo), sum(DISTINCT Int64(1))] 03)----CoalesceBatchesExec: target_batch_size=8192 04)------RepartitionExec: partitioning=Hash([id@0], 4), input_partitions=5 -05)--------AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[ARRAY_AGG(DISTINCT a.foo), sum(DISTINCT Int64(1))], ordering_mode=Sorted +05)--------AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[array_agg(DISTINCT a.foo), sum(DISTINCT Int64(1))], ordering_mode=Sorted 06)----------UnionExec 07)------------ProjectionExec: expr=[1 as id, 2 as foo] 08)--------------PlaceholderRowExec @@ -3705,7 +3705,7 @@ select column3 as tag from t_source; -# Demonstate the contents +# Demonstrate the contents query PPPPPPPPTT select * from t; ---- @@ -3816,7 +3816,7 @@ select column3 as tag from t_source; -# Demonstate the contents +# Demonstrate the contents query DDTT select * from t; ---- @@ -3914,7 +3914,7 @@ select column3 as tag from t_source; -# Demonstate the contents +# Demonstrate the contents query DDDDTT select * from t; ---- @@ -4108,7 +4108,7 @@ select sum(c1), arrow_typeof(sum(c1)) from d_table; ---- 100 Decimal128(20, 3) -# aggregate sum with deciaml +# aggregate sum with decimal statement ok create table t (c decimal(35, 3)) as values (10), (null), (20); @@ -5418,6 +5418,34 @@ SELECT LAST_VALUE(column1 ORDER BY column2 DESC) IGNORE NULLS FROM t; statement ok DROP TABLE t; +# Test for CASE with NULL in aggregate function +statement ok +CREATE TABLE example(data double precision); + +statement ok +INSERT INTO example VALUES (1), (2), (NULL), (4); + +query RR +SELECT + sum(CASE WHEN data is NULL THEN NULL ELSE data+1 END) as then_null, + sum(CASE WHEN data is NULL THEN data+1 ELSE NULL END) as else_null +FROM example; +---- +10 NULL + +query R +SELECT + CASE data WHEN 1 THEN NULL WHEN 2 THEN 3.3 ELSE NULL END as case_null +FROM example; +---- +NULL +3.3 +NULL +NULL + +statement ok +drop table example; + # Test Convert FirstLast optimizer rule statement ok CREATE EXTERNAL TABLE convert_first_last_table ( diff --git a/datafusion/sqllogictest/test_files/array.slt 
b/datafusion/sqllogictest/test_files/array.slt index 7917f1d78da8..f2972e4c14c2 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -5690,7 +5690,7 @@ select ---- [] [] [0] [0] -# Test range for other egde cases +# Test range for other edge cases query ???????? select range(9223372036854775807, 9223372036854775807, -1) as c1, @@ -5828,7 +5828,7 @@ select [-9223372036854775808] [9223372036854775807] [0, -9223372036854775808] [0, 9223372036854775807] -# Test generate_series for other egde cases +# Test generate_series for other edge cases query ???? select generate_series(9223372036854775807, 9223372036854775807, -1) as c1, diff --git a/datafusion/sqllogictest/test_files/binary.slt b/datafusion/sqllogictest/test_files/binary.slt index 621cd3e528f1..5c5f9d510e55 100644 --- a/datafusion/sqllogictest/test_files/binary.slt +++ b/datafusion/sqllogictest/test_files/binary.slt @@ -25,7 +25,7 @@ SELECT X'FF01', arrow_typeof(X'FF01'); ---- ff01 Binary -# Invaid hex values +# Invalid hex values query error DataFusion error: Error during planning: Invalid HexStringLiteral 'Z' SELECT X'Z' diff --git a/datafusion/sqllogictest/test_files/binary_view.slt b/datafusion/sqllogictest/test_files/binary_view.slt index de0f0bea7ffb..77ec77c5ecce 100644 --- a/datafusion/sqllogictest/test_files/binary_view.slt +++ b/datafusion/sqllogictest/test_files/binary_view.slt @@ -199,4 +199,4 @@ Raphael R false false true true NULL R NULL NULL NULL NULL statement ok -drop table test; \ No newline at end of file +drop table test; diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt new file mode 100644 index 000000000000..70063b88fb19 --- /dev/null +++ b/datafusion/sqllogictest/test_files/case.slt @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# create test data +statement ok +create table foo (a int, b int) as values (1, 2), (3, 4), (5, 6), (null, null), (6, null), (null, 7); + +# CASE WHEN with condition +query T +SELECT CASE a WHEN 1 THEN 'one' WHEN 3 THEN 'three' ELSE '?' END FROM foo +---- +one +three +? +? +? +? 
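Since `case.slt` is a new test file, it may help to see the same simple-form CASE exercised through the public API. The following is a hedged sketch, assuming the `datafusion` and `tokio` crates are available (they are not part of this patch); the fixture mirrors the `foo` table created at the top of the file.

```rust
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Same fixture as the `foo` table defined at the top of case.slt.
    ctx.sql(
        "CREATE TABLE foo (a INT, b INT) AS VALUES \
         (1, 2), (3, 4), (5, 6), (NULL, NULL), (6, NULL), (NULL, 7)",
    )
    .await?
    .collect()
    .await?;

    // Simple-form CASE: a NULL `a` matches no WHEN arm and falls through to ELSE '?'.
    ctx.sql("SELECT CASE a WHEN 1 THEN 'one' WHEN 3 THEN 'three' ELSE '?' END FROM foo")
        .await?
        .show()
        .await?; // expected rows: one, three, ?, ?, ?, ?

    Ok(())
}
```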
+ +# CASE WHEN with no condition +query I +SELECT CASE WHEN a > 2 THEN a ELSE b END FROM foo +---- +2 +3 +5 +NULL +6 +7 + +# column or explicit null +query I +SELECT CASE WHEN a > 2 THEN b ELSE null END FROM foo +---- +NULL +4 +6 +NULL +NULL +7 + +# column or implicit null +query I +SELECT CASE WHEN a > 2 THEN b END FROM foo +---- +NULL +4 +6 +NULL +NULL +7 + +# scalar or scalar (string) +query T +SELECT CASE WHEN a > 2 THEN 'even' ELSE 'odd' END FROM foo +---- +odd +even +even +odd +even +odd + +# scalar or scalar (int) +query I +SELECT CASE WHEN a > 2 THEN 1 ELSE 0 END FROM foo +---- +0 +1 +1 +0 +1 +0 + +# predicate binary expression with scalars (does not make much sense because the expression in +# this case is always false, so this expression could be rewritten as a literal 0 during planning +query I +SELECT CASE WHEN 1 > 2 THEN 1 ELSE 0 END FROM foo +---- +0 +0 +0 +0 +0 +0 + +# predicate using boolean literal (does not make much sense because the expression in +# this case is always false, so this expression could be rewritten as a literal 0 during planning +query I +SELECT CASE WHEN false THEN 1 ELSE 0 END FROM foo +---- +0 +0 +0 +0 +0 +0 diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index 31e54b981a05..ff7040926caa 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -271,7 +271,7 @@ select * from validate_struct_with_array; {c0: foo, c1: [1, 2, 3], c2: {c0: bar, c1: [2, 3, 4]}} -# Copy parquet with all supported statment overrides +# Copy parquet with all supported statement overrides query IT COPY source_table TO 'test_files/scratch/copy/table_with_options/' diff --git a/datafusion/sqllogictest/test_files/count_star_rule.slt b/datafusion/sqllogictest/test_files/count_star_rule.slt new file mode 100644 index 000000000000..99d358ad17f0 --- /dev/null +++ b/datafusion/sqllogictest/test_files/count_star_rule.slt @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
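The new `count_star_rule.slt` below, together with the removed expectation in `errors.slt` above, documents that zero-argument `COUNT()` is now accepted and planned as `count(Int64(1)) AS count()`. As a hedged, API-level sketch (again assuming the `datafusion` and `tokio` crates, which this patch does not add), the same behaviour can be observed outside sqllogictest:

```rust
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Mirrors the first EXPLAIN in count_star_rule.slt: COUNT() with no
    // arguments over a one-row derived table evaluates to 1.
    ctx.sql("SELECT COUNT() FROM (SELECT 1 AS a, 2 AS b) AS t")
        .await?
        .show()
        .await?; // expected single row: count() = 1

    Ok(())
}
```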
+ +statement ok +CREATE TABLE t1 (a INTEGER, b INTEGER, c INTEGER); + +statement ok +INSERT INTO t1 VALUES +(1, 2, 3), +(1, 5, 6), +(2, 3, 5); + +statement ok +CREATE TABLE t2 (a INTEGER, b INTEGER, c INTEGER); + +query TT +EXPLAIN SELECT COUNT() FROM (SELECT 1 AS a, 2 AS b) AS t; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count()]] +02)--SubqueryAlias: t +03)----EmptyRelation +physical_plan +01)ProjectionExec: expr=[1 as count()] +02)--PlaceholderRowExec + +query TT +EXPLAIN SELECT t1.a, COUNT() FROM t1 GROUP BY t1.a; +---- +logical_plan +01)Aggregate: groupBy=[[t1.a]], aggr=[[count(Int64(1)) AS count()]] +02)--TableScan: t1 projection=[a] +physical_plan +01)AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count()] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count()] +06)----------MemoryExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN SELECT t1.a, COUNT() AS cnt FROM t1 GROUP BY t1.a HAVING COUNT() > 0; +---- +logical_plan +01)Projection: t1.a, count() AS cnt +02)--Filter: count() > Int64(0) +03)----Aggregate: groupBy=[[t1.a]], aggr=[[count(Int64(1)) AS count()]] +04)------TableScan: t1 projection=[a] +physical_plan +01)ProjectionExec: expr=[a@0 as a, count()@1 as cnt] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: count()@1 > 0 +04)------AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count()] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 +07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +08)--------------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count()] +09)----------------MemoryExec: partitions=1, partition_sizes=[1] + +query II +SELECT t1.a, COUNT() AS cnt FROM t1 GROUP BY t1.a HAVING COUNT() > 1; +---- +1 2 + +query TT +EXPLAIN SELECT a, COUNT() OVER (PARTITION BY a) AS count_a FROM t1; +---- +logical_plan +01)Projection: t1.a, count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS count_a +02)--WindowAggr: windowExpr=[[count(Int64(1)) PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +03)----TableScan: t1 projection=[a] +physical_plan +01)ProjectionExec: expr=[a@0 as a, count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as count_a] +02)--WindowAggExec: wdw=[count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] +03)----SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 +06)----------MemoryExec: partitions=1, partition_sizes=[1] + +query II +SELECT a, COUNT() OVER (PARTITION BY a) AS count_a FROM t1 ORDER BY a; +---- +1 2 +1 2 +2 1 + +statement ok +DROP TABLE t1; + +statement 
ok +DROP TABLE t2; diff --git a/datafusion/sqllogictest/test_files/csv_files.slt b/datafusion/sqllogictest/test_files/csv_files.slt index ca3bebe79f27..f7f5aa54dd0d 100644 --- a/datafusion/sqllogictest/test_files/csv_files.slt +++ b/datafusion/sqllogictest/test_files/csv_files.slt @@ -293,3 +293,45 @@ id0 "value0" id1 "value1" id2 "value2" id3 "value3" + +# Handling of newlines in values + +statement ok +SET datafusion.optimizer.repartition_file_min_size = 1; + +statement ok +CREATE EXTERNAL TABLE stored_table_with_newlines_in_values_unsafe ( +col1 TEXT, +col2 TEXT +) STORED AS CSV +LOCATION '../core/tests/data/newlines_in_values.csv'; + +statement error incorrect number of fields +select * from stored_table_with_newlines_in_values_unsafe; + +statement ok +CREATE EXTERNAL TABLE stored_table_with_newlines_in_values_safe ( +col1 TEXT, +col2 TEXT +) STORED AS CSV +LOCATION '../core/tests/data/newlines_in_values.csv' +OPTIONS ('format.newlines_in_values' 'true'); + +query TT +select * from stored_table_with_newlines_in_values_safe; +---- +id message +1 +01)hello +02)world +2 +01)something +02)else +3 +01) +02)many +03)lines +04)make +05)good test +4 unquoted +value end diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index fa25f00974a9..be7fdac71b57 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -103,10 +103,6 @@ SELECT power(1, 2, 3); # Wrong window/aggregate function signature # -# AggregateFunction with wrong number of arguments -query error -select count(); - # AggregateFunction with wrong number of arguments query error select avg(c1, c12) from aggregate_test_100; diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 3a4e8072bbc7..172cbad44dca 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -394,7 +394,7 @@ physical_plan_with_schema statement ok set datafusion.execution.collect_statistics = false; -# Explain ArrayFuncions +# Explain ArrayFunctions statement ok set datafusion.explain.physical_plan_only = false diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index b2be65a609e3..a3cc10e1eeb8 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -2289,10 +2289,10 @@ FROM annotated_data_infinite2 GROUP BY a, b; ---- logical_plan -01)Aggregate: groupBy=[[annotated_data_infinite2.a, annotated_data_infinite2.b]], aggr=[[ARRAY_AGG(annotated_data_infinite2.d) ORDER BY [annotated_data_infinite2.d ASC NULLS LAST]]] +01)Aggregate: groupBy=[[annotated_data_infinite2.a, annotated_data_infinite2.b]], aggr=[[array_agg(annotated_data_infinite2.d) ORDER BY [annotated_data_infinite2.d ASC NULLS LAST]]] 02)--TableScan: annotated_data_infinite2 projection=[a, b, d] physical_plan -01)AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[ARRAY_AGG(annotated_data_infinite2.d) ORDER BY [annotated_data_infinite2.d ASC NULLS LAST]], ordering_mode=Sorted +01)AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[array_agg(annotated_data_infinite2.d) ORDER BY [annotated_data_infinite2.d ASC NULLS LAST]], ordering_mode=Sorted 02)--PartialSortExec: expr=[a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,d@2 ASC NULLS LAST], common_prefix_length=[2] 03)----StreamingTableExec: partition_sizes=1, projection=[a, b, d], infinite_source=true, 
output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST] @@ -2459,12 +2459,12 @@ EXPLAIN SELECT country, (ARRAY_AGG(amount ORDER BY amount ASC)) AS amounts GROUP BY country ---- logical_plan -01)Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts -02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] +01)Projection: sales_global.country, array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts +02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] 03)----TableScan: sales_global projection=[country, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as amounts] -02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] +01)ProjectionExec: expr=[country@0 as country, array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as amounts] +02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] 03)----SortExec: expr=[amount@1 ASC NULLS LAST], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] @@ -2488,13 +2488,13 @@ EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, GROUP BY s.country ---- logical_plan -01)Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, sum(s.amount) AS sum1 -02)--Aggregate: groupBy=[[s.country]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], sum(CAST(s.amount AS Float64))]] +01)Projection: s.country, array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, sum(s.amount) AS sum1 +02)--Aggregate: groupBy=[[s.country]], aggr=[[array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST], sum(CAST(s.amount AS Float64))]] 03)----SubqueryAlias: s 04)------TableScan: sales_global projection=[country, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@1 as amounts, sum(s.amount)@2 as sum1] -02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], sum(s.amount)] +01)ProjectionExec: expr=[country@0 as country, array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@1 as amounts, sum(s.amount)@2 as sum1] +02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST], sum(s.amount)] 03)----SortExec: expr=[amount@1 DESC], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] @@ -2531,14 +2531,14 @@ EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, GROUP BY s.country ---- logical_plan -01)Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, sum(s.amount) AS sum1 -02)--Aggregate: groupBy=[[s.country]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], sum(CAST(s.amount AS Float64))]] +01)Projection: s.country, array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, sum(s.amount) AS sum1 +02)--Aggregate: groupBy=[[s.country]], aggr=[[array_agg(s.amount) ORDER BY [s.amount DESC 
NULLS FIRST], sum(CAST(s.amount AS Float64))]] 03)----SubqueryAlias: s 04)------Sort: sales_global.country ASC NULLS LAST 05)--------TableScan: sales_global projection=[country, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@1 as amounts, sum(s.amount)@2 as sum1] -02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], sum(s.amount)], ordering_mode=Sorted +01)ProjectionExec: expr=[country@0 as country, array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@1 as amounts, sum(s.amount)@2 as sum1] +02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST], sum(s.amount)], ordering_mode=Sorted 03)----SortExec: expr=[country@0 ASC NULLS LAST,amount@1 DESC], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] @@ -2567,14 +2567,14 @@ EXPLAIN SELECT s.country, s.zip_code, ARRAY_AGG(s.amount ORDER BY s.amount DESC) GROUP BY s.country, s.zip_code ---- logical_plan -01)Projection: s.country, s.zip_code, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, sum(s.amount) AS sum1 -02)--Aggregate: groupBy=[[s.country, s.zip_code]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], sum(CAST(s.amount AS Float64))]] +01)Projection: s.country, s.zip_code, array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, sum(s.amount) AS sum1 +02)--Aggregate: groupBy=[[s.country, s.zip_code]], aggr=[[array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST], sum(CAST(s.amount AS Float64))]] 03)----SubqueryAlias: s 04)------Sort: sales_global.country ASC NULLS LAST 05)--------TableScan: sales_global projection=[zip_code, country, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, zip_code@1 as zip_code, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@2 as amounts, sum(s.amount)@3 as sum1] -02)--AggregateExec: mode=Single, gby=[country@1 as country, zip_code@0 as zip_code], aggr=[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], sum(s.amount)], ordering_mode=PartiallySorted([0]) +01)ProjectionExec: expr=[country@0 as country, zip_code@1 as zip_code, array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@2 as amounts, sum(s.amount)@3 as sum1] +02)--AggregateExec: mode=Single, gby=[country@1 as country, zip_code@0 as zip_code], aggr=[array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST], sum(s.amount)], ordering_mode=PartiallySorted([0]) 03)----SortExec: expr=[country@1 ASC NULLS LAST,amount@2 DESC], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] @@ -2603,14 +2603,14 @@ EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.country DESC) AS amounts GROUP BY s.country ---- logical_plan -01)Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST] AS amounts, sum(s.amount) AS sum1 -02)--Aggregate: groupBy=[[s.country]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST], sum(CAST(s.amount AS Float64))]] +01)Projection: s.country, array_agg(s.amount) ORDER BY [s.country DESC NULLS FIRST] AS amounts, sum(s.amount) AS sum1 +02)--Aggregate: groupBy=[[s.country]], aggr=[[array_agg(s.amount) ORDER BY [s.country DESC NULLS FIRST], sum(CAST(s.amount AS Float64))]] 03)----SubqueryAlias: s 04)------Sort: sales_global.country ASC NULLS LAST 05)--------TableScan: sales_global projection=[country, amount] physical_plan 
-01)ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST]@1 as amounts, sum(s.amount)@2 as sum1] -02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST], sum(s.amount)], ordering_mode=Sorted +01)ProjectionExec: expr=[country@0 as country, array_agg(s.amount) ORDER BY [s.country DESC NULLS FIRST]@1 as amounts, sum(s.amount)@2 as sum1] +02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[array_agg(s.amount) ORDER BY [s.country DESC NULLS FIRST], sum(s.amount)], ordering_mode=Sorted 03)----SortExec: expr=[country@0 ASC NULLS LAST], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] @@ -2638,14 +2638,14 @@ EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.country DESC, s.amount D GROUP BY s.country ---- logical_plan -01)Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST] AS amounts, sum(s.amount) AS sum1 -02)--Aggregate: groupBy=[[s.country]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST], sum(CAST(s.amount AS Float64))]] +01)Projection: s.country, array_agg(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST] AS amounts, sum(s.amount) AS sum1 +02)--Aggregate: groupBy=[[s.country]], aggr=[[array_agg(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST], sum(CAST(s.amount AS Float64))]] 03)----SubqueryAlias: s 04)------Sort: sales_global.country ASC NULLS LAST 05)--------TableScan: sales_global projection=[country, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST]@1 as amounts, sum(s.amount)@2 as sum1] -02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST], sum(s.amount)], ordering_mode=Sorted +01)ProjectionExec: expr=[country@0 as country, array_agg(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST]@1 as amounts, sum(s.amount)@2 as sum1] +02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[array_agg(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST], sum(s.amount)], ordering_mode=Sorted 03)----SortExec: expr=[country@0 ASC NULLS LAST,amount@1 DESC], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] @@ -2672,12 +2672,12 @@ EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, GROUP BY country ---- logical_plan -01)Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS amounts, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2 -02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]]] +01)Projection: sales_global.country, array_agg(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS amounts, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, last_value(sales_global.amount) ORDER BY 
[sales_global.amount DESC NULLS FIRST] AS fv2 +02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[array_agg(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]]] 03)----TableScan: sales_global projection=[country, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@1 as amounts, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] -02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]] +01)ProjectionExec: expr=[country@0 as country, array_agg(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@1 as amounts, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] +02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]] 03)----SortExec: expr=[amount@1 DESC], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] @@ -2703,12 +2703,12 @@ EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount ASC) AS amounts, GROUP BY country ---- logical_plan -01)Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2 -02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]]] +01)Projection: sales_global.country, array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2 +02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]]] 03)----TableScan: sales_global projection=[country, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as amounts, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] -02)--AggregateExec: 
mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] +01)ProjectionExec: expr=[country@0 as country, array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as amounts, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] +02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] 03)----SortExec: expr=[amount@1 ASC NULLS LAST], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] @@ -2735,12 +2735,12 @@ EXPLAIN SELECT country, FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, GROUP BY country ---- logical_plan -01)Projection: sales_global.country, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts -02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] +01)Projection: sales_global.country, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2, array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts +02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] 03)----TableScan: sales_global projection=[country, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@2 as fv2, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@3 as amounts] -02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] +01)ProjectionExec: expr=[country@0 as country, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@2 as fv2, array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@3 as amounts] +02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], 
first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] 03)----SortExec: expr=[amount@1 ASC NULLS LAST], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] @@ -2765,12 +2765,12 @@ EXPLAIN SELECT country, SUM(amount ORDER BY ts DESC) AS sum1, GROUP BY country ---- logical_plan -01)Projection: sales_global.country, sum(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS sum1, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts -02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[sum(CAST(sales_global.amount AS Float64)) ORDER BY [sales_global.ts DESC NULLS FIRST], ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] +01)Projection: sales_global.country, sum(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS sum1, array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts +02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[sum(CAST(sales_global.amount AS Float64)) ORDER BY [sales_global.ts DESC NULLS FIRST], array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] 03)----TableScan: sales_global projection=[country, ts, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, sum(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as sum1, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as amounts] -02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[sum(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] +01)ProjectionExec: expr=[country@0 as country, sum(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as sum1, array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as amounts] +02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[sum(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] 03)----SortExec: expr=[amount@2 ASC NULLS LAST], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] @@ -3036,14 +3036,14 @@ EXPLAIN SELECT ARRAY_AGG(amount ORDER BY ts ASC) AS array_agg1 FROM sales_global ---- logical_plan -01)Projection: ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS array_agg1 -02)--Aggregate: groupBy=[[]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]]] +01)Projection: array_agg(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS array_agg1 +02)--Aggregate: groupBy=[[]], aggr=[[array_agg(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]]] 03)----TableScan: sales_global projection=[ts, amount] physical_plan -01)ProjectionExec: expr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@0 as array_agg1] -02)--AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]] +01)ProjectionExec: expr=[array_agg(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@0 as array_agg1] +02)--AggregateExec: mode=Final, gby=[], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]] 03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[], 
aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]] +04)------AggregateExec: mode=Partial, gby=[], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]] 05)--------SortExec: expr=[ts@0 ASC NULLS LAST], preserve_partitioning=[true] 06)----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 07)------------MemoryExec: partitions=1, partition_sizes=[1] @@ -3060,14 +3060,14 @@ EXPLAIN SELECT ARRAY_AGG(amount ORDER BY ts DESC) AS array_agg1 FROM sales_global ---- logical_plan -01)Projection: ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS array_agg1 -02)--Aggregate: groupBy=[[]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]]] +01)Projection: array_agg(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS array_agg1 +02)--Aggregate: groupBy=[[]], aggr=[[array_agg(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]]] 03)----TableScan: sales_global projection=[ts, amount] physical_plan -01)ProjectionExec: expr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@0 as array_agg1] -02)--AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]] +01)ProjectionExec: expr=[array_agg(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@0 as array_agg1] +02)--AggregateExec: mode=Final, gby=[], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]] 03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]] +04)------AggregateExec: mode=Partial, gby=[], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]] 05)--------SortExec: expr=[ts@0 DESC], preserve_partitioning=[true] 06)----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 07)------------MemoryExec: partitions=1, partition_sizes=[1] @@ -3084,14 +3084,14 @@ EXPLAIN SELECT ARRAY_AGG(amount ORDER BY amount ASC) AS array_agg1 FROM sales_global ---- logical_plan -01)Projection: ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS array_agg1 -02)--Aggregate: groupBy=[[]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] +01)Projection: array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS array_agg1 +02)--Aggregate: groupBy=[[]], aggr=[[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] 03)----TableScan: sales_global projection=[amount] physical_plan -01)ProjectionExec: expr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@0 as array_agg1] -02)--AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] +01)ProjectionExec: expr=[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@0 as array_agg1] +02)--AggregateExec: mode=Final, gby=[], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] 03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] +04)------AggregateExec: mode=Partial, gby=[], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] 05)--------SortExec: expr=[amount@0 ASC NULLS LAST], preserve_partitioning=[true] 
06)----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 07)------------MemoryExec: partitions=1, partition_sizes=[1] @@ -3111,17 +3111,17 @@ EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount ASC) AS array_agg1 ---- logical_plan 01)Sort: sales_global.country ASC NULLS LAST -02)--Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS array_agg1 -03)----Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] +02)--Projection: sales_global.country, array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS array_agg1 +03)----Aggregate: groupBy=[[sales_global.country]], aggr=[[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] 04)------TableScan: sales_global projection=[country, amount] physical_plan 01)SortPreservingMergeExec: [country@0 ASC NULLS LAST] 02)--SortExec: expr=[country@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as array_agg1] -04)------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] +03)----ProjectionExec: expr=[country@0 as country, array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as array_agg1] +04)------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] 05)--------CoalesceBatchesExec: target_batch_size=4 06)----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 -07)------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] +07)------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] 08)--------------SortExec: expr=[amount@1 ASC NULLS LAST], preserve_partitioning=[true] 09)----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 10)------------------MemoryExec: partitions=1, partition_sizes=[1] @@ -3147,17 +3147,17 @@ EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, ---- logical_plan 01)Sort: sales_global.country ASC NULLS LAST -02)--Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS amounts, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2 -03)----Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]]] +02)--Projection: sales_global.country, array_agg(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS amounts, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2 +03)----Aggregate: groupBy=[[sales_global.country]], aggr=[[array_agg(sales_global.amount) ORDER BY [sales_global.amount 
DESC NULLS FIRST], first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]]] 04)------TableScan: sales_global projection=[country, amount] physical_plan 01)SortPreservingMergeExec: [country@0 ASC NULLS LAST] 02)--SortExec: expr=[country@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@1 as amounts, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] -04)------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]] +03)----ProjectionExec: expr=[country@0 as country, array_agg(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@1 as amounts, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] +04)------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]] 05)--------CoalesceBatchesExec: target_batch_size=4 06)----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 -07)------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]] +07)------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]] 08)--------------SortExec: expr=[amount@1 DESC], preserve_partitioning=[true] 09)----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 10)------------------MemoryExec: partitions=1, partition_sizes=[1] @@ -4971,10 +4971,10 @@ ORDER BY a, b; ---- logical_plan 01)Sort: multiple_ordered_table.a ASC NULLS LAST, multiple_ordered_table.b ASC NULLS LAST -02)--Aggregate: groupBy=[[multiple_ordered_table.a, multiple_ordered_table.b]], aggr=[[ARRAY_AGG(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]]] +02)--Aggregate: groupBy=[[multiple_ordered_table.a, multiple_ordered_table.b]], aggr=[[array_agg(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]]] 03)----TableScan: multiple_ordered_table projection=[a, b, c] physical_plan -01)AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[ARRAY_AGG(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]], ordering_mode=Sorted +01)AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[array_agg(multiple_ordered_table.c) 
ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]], ordering_mode=Sorted 02)--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_orderings=[[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], has_header=true query II? diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 95bea1223a9c..ddacf1cc6a79 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -168,6 +168,7 @@ datafusion.catalog.format NULL datafusion.catalog.has_header false datafusion.catalog.information_schema true datafusion.catalog.location NULL +datafusion.catalog.newlines_in_values false datafusion.execution.aggregate.scalar_update_factor 10 datafusion.execution.batch_size 8192 datafusion.execution.coalesce_batches true @@ -183,17 +184,17 @@ datafusion.execution.parquet.bloom_filter_fpp NULL datafusion.execution.parquet.bloom_filter_ndv NULL datafusion.execution.parquet.bloom_filter_on_read true datafusion.execution.parquet.bloom_filter_on_write false -datafusion.execution.parquet.column_index_truncate_length NULL +datafusion.execution.parquet.column_index_truncate_length 64 datafusion.execution.parquet.compression zstd(3) datafusion.execution.parquet.created_by datafusion -datafusion.execution.parquet.data_page_row_count_limit 18446744073709551615 +datafusion.execution.parquet.data_page_row_count_limit 20000 datafusion.execution.parquet.data_pagesize_limit 1048576 -datafusion.execution.parquet.dictionary_enabled NULL +datafusion.execution.parquet.dictionary_enabled true datafusion.execution.parquet.dictionary_page_size_limit 1048576 datafusion.execution.parquet.enable_page_index true datafusion.execution.parquet.encoding NULL datafusion.execution.parquet.max_row_group_size 1048576 -datafusion.execution.parquet.max_statistics_size NULL +datafusion.execution.parquet.max_statistics_size 4096 datafusion.execution.parquet.maximum_buffered_record_batches_per_stream 2 datafusion.execution.parquet.maximum_parallel_row_group_writers 1 datafusion.execution.parquet.metadata_size_hint NULL @@ -252,6 +253,7 @@ datafusion.catalog.format NULL Type of `TableProvider` to use when loading `defa datafusion.catalog.has_header false Default value for `format.has_header` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. datafusion.catalog.information_schema true Should DataFusion provide access to `information_schema` virtual tables for displaying schema information datafusion.catalog.location NULL Location scanned to load tables for `default` schema +datafusion.catalog.newlines_in_values false Specifies whether newlines in (quoted) CSV values are supported. This is the default value for `format.newlines_in_values` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. Parsing newlines in quoted values may be affected by execution behaviour such as parallel file scanning. Setting this to `true` ensures that newlines in values are parsed successfully, which may reduce performance. datafusion.execution.aggregate.scalar_update_factor 10 Specifies the threshold for using `ScalarValue`s to update accumulators during high-cardinality aggregations for each input batch. The aggregation is considered high-cardinality if the number of affected groups is greater than or equal to `batch_size / scalar_update_factor`. 
In such cases, `ScalarValue`s are utilized for updating accumulators, rather than the default batch-slice approach. This can lead to performance improvements. By adjusting the `scalar_update_factor`, you can balance the trade-off between more efficient accumulator updates and the number of groups affected. datafusion.execution.batch_size 8192 Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption datafusion.execution.coalesce_batches true When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting @@ -262,32 +264,32 @@ datafusion.execution.listing_table_ignore_subdirectory true Should sub directori datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption datafusion.execution.meta_fetch_concurrency 32 Number of files to read in parallel when inferring schema and statistics datafusion.execution.minimum_parallel_output_files 4 Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. -datafusion.execution.parquet.allow_single_file_parallelism true Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. -datafusion.execution.parquet.bloom_filter_fpp NULL Sets bloom filter false positive probability. If NULL, uses default parquet writer setting -datafusion.execution.parquet.bloom_filter_ndv NULL Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting -datafusion.execution.parquet.bloom_filter_on_read true Use any available bloom filters when reading parquet files -datafusion.execution.parquet.bloom_filter_on_write false Write bloom filters for all columns when creating parquet files -datafusion.execution.parquet.column_index_truncate_length NULL Sets column index truncate length -datafusion.execution.parquet.compression zstd(3) Sets default parquet compression codec Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting -datafusion.execution.parquet.created_by datafusion Sets "created by" property -datafusion.execution.parquet.data_page_row_count_limit 18446744073709551615 Sets best effort maximum number of rows in data page -datafusion.execution.parquet.data_pagesize_limit 1048576 Sets best effort maximum size of data page in bytes -datafusion.execution.parquet.dictionary_enabled NULL Sets if dictionary encoding is enabled. 
If NULL, uses default parquet writer setting -datafusion.execution.parquet.dictionary_page_size_limit 1048576 Sets best effort maximum dictionary page size, in bytes -datafusion.execution.parquet.enable_page_index true If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. -datafusion.execution.parquet.encoding NULL Sets default encoding for any column Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting -datafusion.execution.parquet.max_row_group_size 1048576 Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. -datafusion.execution.parquet.max_statistics_size NULL Sets max statistics size for any column. If NULL, uses default parquet writer setting -datafusion.execution.parquet.maximum_buffered_record_batches_per_stream 2 By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. -datafusion.execution.parquet.maximum_parallel_row_group_writers 1 By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. -datafusion.execution.parquet.metadata_size_hint NULL If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer -datafusion.execution.parquet.pruning true If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file -datafusion.execution.parquet.pushdown_filters false If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". -datafusion.execution.parquet.reorder_filters false If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query -datafusion.execution.parquet.skip_metadata true If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. 
This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata -datafusion.execution.parquet.statistics_enabled NULL Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting -datafusion.execution.parquet.write_batch_size 1024 Sets write_batch_size in bytes -datafusion.execution.parquet.writer_version 1.0 Sets parquet writer version valid values are "1.0" and "2.0" +datafusion.execution.parquet.allow_single_file_parallelism true (writing) Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. +datafusion.execution.parquet.bloom_filter_fpp NULL (writing) Sets bloom filter false positive probability. If NULL, uses default parquet writer setting +datafusion.execution.parquet.bloom_filter_ndv NULL (writing) Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting +datafusion.execution.parquet.bloom_filter_on_read true (writing) Use any available bloom filters when reading parquet files +datafusion.execution.parquet.bloom_filter_on_write false (writing) Write bloom filters for all columns when creating parquet files +datafusion.execution.parquet.column_index_truncate_length 64 (writing) Sets column index truncate length +datafusion.execution.parquet.compression zstd(3) (writing) Sets default parquet compression codec. Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting Note that this default setting is not the same as the default parquet writer setting. +datafusion.execution.parquet.created_by datafusion (writing) Sets "created by" property +datafusion.execution.parquet.data_page_row_count_limit 20000 (writing) Sets best effort maximum number of rows in data page +datafusion.execution.parquet.data_pagesize_limit 1048576 (writing) Sets best effort maximum size of data page in bytes +datafusion.execution.parquet.dictionary_enabled true (writing) Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting +datafusion.execution.parquet.dictionary_page_size_limit 1048576 (writing) Sets best effort maximum dictionary page size, in bytes +datafusion.execution.parquet.enable_page_index true (reading) If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. +datafusion.execution.parquet.encoding NULL (writing) Sets default encoding for any column. Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting +datafusion.execution.parquet.max_row_group_size 1048576 (writing) Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. +datafusion.execution.parquet.max_statistics_size 4096 (writing) Sets max statistics size for any column. 
If NULL, uses default parquet writer setting +datafusion.execution.parquet.maximum_buffered_record_batches_per_stream 2 (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. +datafusion.execution.parquet.maximum_parallel_row_group_writers 1 (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. +datafusion.execution.parquet.metadata_size_hint NULL (reading) If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer +datafusion.execution.parquet.pruning true (reading) If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file +datafusion.execution.parquet.pushdown_filters false (reading) If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". +datafusion.execution.parquet.reorder_filters false (reading) If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query +datafusion.execution.parquet.skip_metadata true (reading) If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata +datafusion.execution.parquet.statistics_enabled NULL (writing) Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting +datafusion.execution.parquet.write_batch_size 1024 (writing) Sets write_batch_size in bytes +datafusion.execution.parquet.writer_version 1.0 (writing) Sets parquet writer version valid values are "1.0" and "2.0" datafusion.execution.planning_concurrency 13 Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system datafusion.execution.soft_max_rows_per_output_file 50000000 Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. 
There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max datafusion.execution.sort_in_place_threshold_bytes 1048576 When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. @@ -368,9 +370,12 @@ datafusion.execution.time_zone +00:00 The default time zone Some functions, e.g. # show empty verbose -query TTT +statement error DataFusion error: Error during planning: '' is not a variable which can be viewed with 'SHOW' SHOW VERBOSE ----- + +# show nonsense verbose +statement error DataFusion error: Error during planning: 'nonsense' is not a variable which can be viewed with 'SHOW' +SHOW NONSENSE VERBOSE # information_schema_describe_table @@ -506,9 +511,7 @@ SHOW columns from datafusion.public.t2 # show_non_existing_variable -# FIXME -# currently we cannot know whether a variable exists, this will output 0 row instead -statement ok +statement error DataFusion error: Error during planning: 'something_unknown' is not a variable which can be viewed with 'SHOW' SHOW SOMETHING_UNKNOWN; statement ok diff --git a/datafusion/sqllogictest/test_files/interval.slt b/datafusion/sqllogictest/test_files/interval.slt index eab4eed00269..afb262cf95a5 100644 --- a/datafusion/sqllogictest/test_files/interval.slt +++ b/datafusion/sqllogictest/test_files/interval.slt @@ -325,7 +325,7 @@ select ---- Interval(MonthDayNano) Interval(MonthDayNano) -# cast with explicit cast sytax +# cast with explicit cast syntax query TT select arrow_typeof(cast ('5 months' as interval)), diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index efebba1779cf..4f01d2b2c72b 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -1054,3 +1054,45 @@ DROP TABLE t1; statement ok DROP TABLE t2; + +# Join Using Issue with Cast Expr +# Found issue: https://github.com/apache/datafusion/issues/11412 + +statement ok +/*DML*/CREATE TABLE t60(v0 BIGINT, v1 BIGINT, v2 BOOLEAN, v3 BOOLEAN); + +statement ok +/*DML*/CREATE TABLE t0(v0 DOUBLE, v1 BIGINT); + +statement ok +/*DML*/CREATE TABLE t1(v0 DOUBLE); + +query I +SELECT COUNT(*) +FROM t1 +NATURAL JOIN t60 +INNER JOIN t0 +ON t60.v1 = t0.v0 +AND t0.v1 > t60.v1; +---- +0 + +query I +SELECT COUNT(*) +FROM t1 +JOIN t60 +USING (v0) +INNER JOIN t0 +ON t60.v1 = t0.v0 +AND t0.v1 > t60.v1; +---- +0 + +statement ok +DROP TABLE t60; + +statement ok +DROP TABLE t0; + +statement ok +DROP TABLE t1; \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index b9897f81a107..441ccb7d99d5 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -3532,7 +3532,7 @@ physical_plan 03)--RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 04)----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true -# Currently datafusion cannot pushdown filter conditions with scalar UDF into +# Currently datafusion can pushdown filter conditions with scalar UDF into # cross join. 
query TT EXPLAIN SELECT * @@ -3540,19 +3540,16 @@ FROM annotated_data as t1, annotated_data as t2 WHERE EXAMPLE(t1.a, t2.a) > 3 ---- logical_plan -01)Filter: example(CAST(t1.a AS Float64), CAST(t2.a AS Float64)) > Float64(3) -02)--CrossJoin: -03)----SubqueryAlias: t1 -04)------TableScan: annotated_data projection=[a0, a, b, c, d] -05)----SubqueryAlias: t2 -06)------TableScan: annotated_data projection=[a0, a, b, c, d] +01)Inner Join: Filter: example(CAST(t1.a AS Float64), CAST(t2.a AS Float64)) > Float64(3) +02)--SubqueryAlias: t1 +03)----TableScan: annotated_data projection=[a0, a, b, c, d] +04)--SubqueryAlias: t2 +05)----TableScan: annotated_data projection=[a0, a, b, c, d] physical_plan -01)CoalesceBatchesExec: target_batch_size=2 -02)--FilterExec: example(CAST(a@1 AS Float64), CAST(a@6 AS Float64)) > 3 -03)----CrossJoinExec -04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true -05)------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -06)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +01)NestedLoopJoinExec: join_type=Inner, filter=example(CAST(a@0 AS Float64), CAST(a@1 AS Float64)) > 3 +02)--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +03)--RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +04)----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true #### # Config teardown diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index fb8917a5f4fe..e530e14df66e 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -131,17 +131,23 @@ SELECT MAKE_MAP([1,2], ['a', 'b'], [3,4], ['b']); ---- {[1, 2]: [a, b], [3, 4]: [b]} -query error +query ? SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30); +---- +{POST: 41, HEAD: ab, PATCH: 30} query error SELECT MAKE_MAP('POST', 41, 'HEAD', 33, null, 30); -query error -SELECT MAKE_MAP('POST', 41, 123, 33,'PATCH', 30); +query ? +SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30); +---- +{POST: 41, HEAD: ab, PATCH: 30} -query error +query ? SELECT MAKE_MAP() +---- +{} query error SELECT MAKE_MAP('POST', 41, 'HEAD'); @@ -296,3 +302,11 @@ SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'), 'LargeList(Utf8)'), a {POST: 41, HEAD: 33, PATCH: 30} {POST: 41, HEAD: 33, PATCH: 30} {POST: 41, HEAD: 33, PATCH: 30} + + +query ? 
+VALUES (MAP(['a'], [1])), (MAP(['b'], [2])), (MAP(['c', 'a'], [3, 1])) +---- +{a: 1} +{b: 2} +{c: 3, a: 1} diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index 6ff804c3065d..6884d762612d 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -112,7 +112,7 @@ SELECT iszero(1.0), iszero(0.0), iszero(-0.0), iszero(NULL) ---- false true true NULL -# abs: empty argumnet +# abs: empty argument statement error SELECT abs(); diff --git a/datafusion/sqllogictest/test_files/misc.slt b/datafusion/sqllogictest/test_files/misc.slt index 9f4710eb9bcc..9bd3023b56f7 100644 --- a/datafusion/sqllogictest/test_files/misc.slt +++ b/datafusion/sqllogictest/test_files/misc.slt @@ -30,6 +30,10 @@ query I select 1 where NULL ---- +# Where clause does not accept non boolean and has nice error message +query error Cannot create filter with non\-boolean predicate 'Utf8\("foo"\)' returning Utf8 +select 1 where 'foo' + query I select 1 where NULL and 1 = 1 ---- diff --git a/datafusion/sqllogictest/test_files/options.slt b/datafusion/sqllogictest/test_files/options.slt index ba9eedcbbd34..aafaa054964e 100644 --- a/datafusion/sqllogictest/test_files/options.slt +++ b/datafusion/sqllogictest/test_files/options.slt @@ -42,7 +42,7 @@ physical_plan statement ok set datafusion.execution.coalesce_batches = false -# expect no coalsece +# expect no coalescence query TT explain SELECT * FROM a WHERE c0 < 1; ---- diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index 51de40fb1972..d0a6d6adc107 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -326,6 +326,13 @@ select column1 + column2 from foo group by column1, column2 ORDER BY column2 des 7 3 +# Test issue: https://github.com/apache/datafusion/issues/11549 +query I +select column1 from foo order by log(column2); +---- +1 +3 +5 # Cleanup statement ok @@ -512,7 +519,7 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( ) STORED AS CSV WITH ORDER(c11) -WITH ORDER(c12 DESC) +WITH ORDER(c12 DESC NULLS LAST) LOCATION '../../testing/data/csv/aggregate_test_100.csv' OPTIONS ('format.has_header' 'true'); @@ -547,34 +554,34 @@ physical_plan 04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11], output_ordering=[c11@0 ASC NULLS LAST], has_header=true query TT - EXPLAIN SELECT LOG(c11, c12) as log_c11_base_c12 + EXPLAIN SELECT LOG(c12, c11) as log_c11_base_c12 FROM aggregate_test_100 ORDER BY log_c11_base_c12; ---- logical_plan 01)Sort: log_c11_base_c12 ASC NULLS LAST -02)--Projection: log(CAST(aggregate_test_100.c11 AS Float64), aggregate_test_100.c12) AS log_c11_base_c12 +02)--Projection: log(aggregate_test_100.c12, CAST(aggregate_test_100.c11 AS Float64)) AS log_c11_base_c12 03)----TableScan: aggregate_test_100 projection=[c11, c12] physical_plan 01)SortPreservingMergeExec: [log_c11_base_c12@0 ASC NULLS LAST] -02)--ProjectionExec: expr=[log(CAST(c11@0 AS Float64), c12@1) as log_c11_base_c12] +02)--ProjectionExec: expr=[log(c12@1, CAST(c11@0 AS Float64)) as log_c11_base_c12] 03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_orderings=[[c11@0 ASC NULLS LAST], [c12@1 DESC]], has_header=true +04)------CsvExec: file_groups={1 group: 
[[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_orderings=[[c11@0 ASC NULLS LAST], [c12@1 DESC NULLS LAST]], has_header=true query TT -EXPLAIN SELECT LOG(c12, c11) as log_c12_base_c11 +EXPLAIN SELECT LOG(c11, c12) as log_c12_base_c11 FROM aggregate_test_100 -ORDER BY log_c12_base_c11 DESC; +ORDER BY log_c12_base_c11 DESC NULLS LAST; ---- logical_plan -01)Sort: log_c12_base_c11 DESC NULLS FIRST -02)--Projection: log(aggregate_test_100.c12, CAST(aggregate_test_100.c11 AS Float64)) AS log_c12_base_c11 +01)Sort: log_c12_base_c11 DESC NULLS LAST +02)--Projection: log(CAST(aggregate_test_100.c11 AS Float64), aggregate_test_100.c12) AS log_c12_base_c11 03)----TableScan: aggregate_test_100 projection=[c11, c12] physical_plan -01)SortPreservingMergeExec: [log_c12_base_c11@0 DESC] -02)--ProjectionExec: expr=[log(c12@1, CAST(c11@0 AS Float64)) as log_c12_base_c11] +01)SortPreservingMergeExec: [log_c12_base_c11@0 DESC NULLS LAST] +02)--ProjectionExec: expr=[log(CAST(c11@0 AS Float64), c12@1) as log_c12_base_c11] 03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_orderings=[[c11@0 ASC NULLS LAST], [c12@1 DESC]], has_header=true +04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_orderings=[[c11@0 ASC NULLS LAST], [c12@1 DESC NULLS LAST]], has_header=true statement ok drop table aggregate_test_100; @@ -1132,3 +1139,10 @@ physical_plan 02)--ProjectionExec: expr=[CAST(inc_col@0 > desc_col@1 AS Int32) as c] 03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[inc_col, desc_col], output_orderings=[[inc_col@0 ASC NULLS LAST], [desc_col@1 DESC]], has_header=true + +# Union a query with the actual data and one with a constant +query I +SELECT (SELECT c from ordered_table ORDER BY c LIMIT 1) UNION ALL (SELECT 23 as c from ordered_table ORDER BY c LIMIT 1) ORDER BY c; +---- +0 +23 diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 5daa9333fb36..ff9afa94f40a 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1203,7 +1203,7 @@ FROM t1 999 999 -# case_when_else_with_null_contant() +# case_when_else_with_null_constant() query I SELECT CASE WHEN c1 = 'a' THEN 1 @@ -1238,27 +1238,27 @@ SELECT CASE WHEN NULL THEN 'foo' ELSE 'bar' END bar # case_expr_with_null() -query ? +query I select case when b is null then null else b end from (select a,b from (values (1,null),(2,3)) as t (a,b)) a; ---- NULL 3 -query ? +query I select case when b is null then null else b end from (select a,b from (values (1,1),(2,3)) as t (a,b)) a; ---- 1 3 # case_expr_with_nulls() -query ? +query I select case when b is null then null when b < 3 then null when b >=3 then b + 1 else b end from (select a,b from (values (1,null),(1,2),(2,3)) as t (a,b)) a ---- NULL NULL 4 -query ? 
+query I select case b when 1 then null when 2 then null when 3 then b + 1 else b end from (select a,b from (values (1,null),(1,2),(2,3)) as t (a,b)) a; ---- NULL @@ -1982,3 +1982,37 @@ query I select strpos('joséésoj', arrow_cast(null, 'Utf8')); ---- NULL + +statement ok +CREATE TABLE t1 (v1 int) AS VALUES (1), (2), (3); + +query I +SELECT * FROM t1 ORDER BY ACOS(SIN(v1)); +---- +2 +1 +3 + +query I +SELECT * FROM t1 ORDER BY ACOSH(SIN(v1)); +---- +1 +2 +3 + +query I +SELECT * FROM t1 ORDER BY ASIN(SIN(v1)); +---- +3 +1 +2 + +query I +SELECT * FROM t1 ORDER BY ATANH(SIN(v1)); +---- +3 +1 +2 + +statement ok +drop table t1; diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 03426dec874f..6884efc07e15 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -613,6 +613,33 @@ END; ---- 2 +# select case when type is null +query I +select CASE + WHEN NULL THEN 1 + ELSE 2 +END; +---- +2 + +# select case then type is null +query I +select CASE + WHEN 10 > 5 THEN NULL + ELSE 2 +END; +---- +NULL + +# select case else type is null +query I +select CASE + WHEN 10 = 5 THEN 1 + ELSE NULL +END; +---- +NULL + # Binary Expression for LargeUtf8 # issue: https://github.com/apache/datafusion/issues/5893 statement ok diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index f4e492649b9f..2ca2d49997a6 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -1161,7 +1161,7 @@ ts_data_secs 2020-09-08T00:00:00 ts_data_secs 2020-09-08T00:00:00 ts_data_secs 2020-09-08T00:00:00 -# Test date trun on different granularity +# Test date turn on different granularity query TP rowsort SELECT 'millisecond', DATE_TRUNC('millisecond', ts) FROM ts_data_nanos UNION ALL diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index 31b16f975e9e..2dc8385bf191 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -602,3 +602,48 @@ physical_plan 09)--ProjectionExec: expr=[1 as count, MAX(Int64(10))@0 as n] 10)----AggregateExec: mode=Single, gby=[], aggr=[MAX(Int64(10))] 11)------PlaceholderRowExec + + +# Test issue: https://github.com/apache/datafusion/issues/11409 +statement ok +CREATE TABLE t1(v0 BIGINT, v1 BIGINT, v2 BIGINT, v3 BOOLEAN); + +statement ok +CREATE TABLE t2(v0 DOUBLE); + +query I +INSERT INTO t1(v0, v2, v1) VALUES (-1229445667, -342312412, -1507138076); +---- +1 + +query I +INSERT INTO t1(v0, v1) VALUES (1541512604, -1229445667); +---- +1 + +query I +INSERT INTO t1(v1, v3, v0, v2) VALUES (-1020641465, false, -1493773377, 1751276473); +---- +1 + +query I +INSERT INTO t1(v3) VALUES (true), (true), (false); +---- +3 + +query I +INSERT INTO t2(v0) VALUES (0.28014577292925047); +---- +1 + +query II +SELECT t1.v2, t1.v0 FROM t2 NATURAL JOIN t1 + UNION ALL +SELECT t1.v2, t1.v0 FROM t2 NATURAL JOIN t1 WHERE (t1.v2 IS NULL); +---- + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt index 93146541e107..d818c0e92795 100644 --- a/datafusion/sqllogictest/test_files/unnest.slt +++ b/datafusion/sqllogictest/test_files/unnest.slt @@ -484,7 +484,7 @@ query error DataFusion error: type_coercion\ncaused by\nThis feature is not impl select 
sum(unnest(generate_series(1,10))); ## TODO: support unnest as a child expr -query error DataFusion error: Internal error: unnest on struct can ony be applied at the root level of select expression +query error DataFusion error: Internal error: unnest on struct can only be applied at the root level of select expression select arrow_typeof(unnest(column5)) from unnest_table; @@ -517,7 +517,7 @@ select unnest(unnest(unnest(column3)['c1'])), column3 from recursive_unnest_tabl 3 [{c0: [2], c1: [[3], [4]]}] 4 [{c0: [2], c1: [[3], [4]]}] -## tripple list unnest +## triple list unnest query I? select unnest(unnest(unnest(column2))), column2 from recursive_unnest_table; ---- diff --git a/datafusion/sqllogictest/test_files/update.slt b/datafusion/sqllogictest/test_files/update.slt index 49b2bd9aa0b5..3d455d7a88ca 100644 --- a/datafusion/sqllogictest/test_files/update.slt +++ b/datafusion/sqllogictest/test_files/update.slt @@ -74,7 +74,7 @@ logical_plan statement ok create table t3(a int, b varchar, c double, d int); -# set from mutiple tables, sqlparser only supports from one table +# set from multiple tables, sqlparser only supports from one table query error DataFusion error: SQL error: ParserError\("Expected end of statement, found: ,"\) explain update t1 set b = t2.b, c = t3.a, d = 1 from t2, t3 where t1.a = t2.a and t1.a = t3.a; diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 5296f13de08a..e9d417c93a57 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -2040,16 +2040,16 @@ query TT EXPLAIN SELECT ARRAY_AGG(c13) as array_agg1 FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 1) ---- logical_plan -01)Projection: ARRAY_AGG(aggregate_test_100.c13) AS array_agg1 -02)--Aggregate: groupBy=[[]], aggr=[[ARRAY_AGG(aggregate_test_100.c13)]] +01)Projection: array_agg(aggregate_test_100.c13) AS array_agg1 +02)--Aggregate: groupBy=[[]], aggr=[[array_agg(aggregate_test_100.c13)]] 03)----Limit: skip=0, fetch=1 04)------Sort: aggregate_test_100.c13 ASC NULLS LAST, fetch=1 05)--------TableScan: aggregate_test_100 projection=[c13] physical_plan -01)ProjectionExec: expr=[ARRAY_AGG(aggregate_test_100.c13)@0 as array_agg1] -02)--AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)] +01)ProjectionExec: expr=[array_agg(aggregate_test_100.c13)@0 as array_agg1] +02)--AggregateExec: mode=Final, gby=[], aggr=[array_agg(aggregate_test_100.c13)] 03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)] +04)------AggregateExec: mode=Partial, gby=[], aggr=[array_agg(aggregate_test_100.c13)] 05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 06)----------GlobalLimitExec: skip=0, fetch=1 07)------------SortExec: TopK(fetch=1), expr=[c13@0 ASC NULLS LAST], preserve_partitioning=[false] @@ -3968,7 +3968,7 @@ CREATE TABLE table_with_pk ( # However, if we know that contains a unique column (e.g. a PRIMARY KEY), # it can be treated as `OVER (ORDER BY ROWS BETWEEN UNBOUNDED PRECEDING # AND CURRENT ROW)` where window frame units change from `RANGE` to `ROWS`. This -# conversion makes the window frame manifestly causal by eliminating the possiblity +# conversion makes the window frame manifestly causal by eliminating the possibility # of ties explicitly (see window frame documentation for a discussion of causality # in this context). The Query below should have `ROWS` in its window frame. 
query TT diff --git a/datafusion/substrait/src/extensions.rs b/datafusion/substrait/src/extensions.rs new file mode 100644 index 000000000000..459d0e0c5ae5 --- /dev/null +++ b/datafusion/substrait/src/extensions.rs @@ -0,0 +1,157 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::common::{plan_err, DataFusionError}; +use std::collections::HashMap; +use substrait::proto::extensions::simple_extension_declaration::{ + ExtensionFunction, ExtensionType, ExtensionTypeVariation, MappingType, +}; +use substrait::proto::extensions::SimpleExtensionDeclaration; + +/// Substrait uses [SimpleExtensions](https://substrait.io/extensions/#simple-extensions) to define +/// behavior of plans in addition to what's supported directly by the protobuf definitions. +/// That includes functions, but also provides support for custom types and variations for existing +/// types. This struct facilitates the use of these extensions in DataFusion. +/// TODO: DF doesn't yet use extensions for type variations +/// TODO: DF doesn't yet provide valid extensionUris +#[derive(Default, Debug, PartialEq)] +pub struct Extensions { + pub functions: HashMap<u32, String>, // anchor -> function name + pub types: HashMap<u32, String>, // anchor -> type name + pub type_variations: HashMap<u32, String>, // anchor -> type variation name +} + +impl Extensions { + /// Registers a function and returns the anchor (reference) to it. If the function has already + /// been registered, it returns the existing anchor. + /// Function names are case-insensitive (converted to lowercase). + pub fn register_function(&mut self, function_name: String) -> u32 { + let function_name = function_name.to_lowercase(); + + // Some functions are named differently in Substrait default extensions than in DF + // Rename those to match the Substrait extensions for interoperability + let function_name = match function_name.as_str() { + "substr" => "substring".to_string(), + _ => function_name, + }; + + match self.functions.iter().find(|(_, f)| *f == &function_name) { + Some((function_anchor, _)) => *function_anchor, // Function has been registered + None => { + // Function has NOT been registered + let function_anchor = self.functions.len() as u32; + self.functions + .insert(function_anchor, function_name.clone()); + function_anchor + } + } + } + + /// Registers a type and returns the anchor (reference) to it. If the type has already + /// been registered, it returns the existing anchor.
+ pub fn register_type(&mut self, type_name: String) -> u32 { + let type_name = type_name.to_lowercase(); + match self.types.iter().find(|(_, t)| *t == &type_name) { + Some((type_anchor, _)) => *type_anchor, // Type has been registered + None => { + // Type has NOT been registered + let type_anchor = self.types.len() as u32; + self.types.insert(type_anchor, type_name.clone()); + type_anchor + } + } + } +} + +impl TryFrom<&Vec<SimpleExtensionDeclaration>> for Extensions { + type Error = DataFusionError; + + fn try_from( + value: &Vec<SimpleExtensionDeclaration>, + ) -> datafusion::common::Result<Self> { + let mut functions = HashMap::new(); + let mut types = HashMap::new(); + let mut type_variations = HashMap::new(); + + for ext in value { + match &ext.mapping_type { + Some(MappingType::ExtensionFunction(ext_f)) => { + functions.insert(ext_f.function_anchor, ext_f.name.to_owned()); + } + Some(MappingType::ExtensionType(ext_t)) => { + types.insert(ext_t.type_anchor, ext_t.name.to_owned()); + } + Some(MappingType::ExtensionTypeVariation(ext_v)) => { + type_variations + .insert(ext_v.type_variation_anchor, ext_v.name.to_owned()); + } + None => return plan_err!("Cannot parse empty extension"), + } + } + + Ok(Extensions { + functions, + types, + type_variations, + }) + } +} + +impl From<Extensions> for Vec<SimpleExtensionDeclaration> { + fn from(val: Extensions) -> Vec<SimpleExtensionDeclaration> { + let mut extensions = vec![]; + for (f_anchor, f_name) in val.functions { + let function_extension = ExtensionFunction { + extension_uri_reference: u32::MAX, + function_anchor: f_anchor, + name: f_name, + }; + let simple_extension = SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionFunction(function_extension)), + }; + extensions.push(simple_extension); + } + + for (t_anchor, t_name) in val.types { + let type_extension = ExtensionType { + extension_uri_reference: u32::MAX, // https://github.com/apache/datafusion/issues/11545 + type_anchor: t_anchor, + name: t_name, + }; + let simple_extension = SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionType(type_extension)), + }; + extensions.push(simple_extension); + } + + for (tv_anchor, tv_name) in val.type_variations { + let type_variation_extension = ExtensionTypeVariation { + extension_uri_reference: u32::MAX, // We don't register proper extension URIs yet + type_variation_anchor: tv_anchor, + name: tv_name, + }; + let simple_extension = SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionTypeVariation( + type_variation_extension, + )), + }; + extensions.push(simple_extension); + } + + extensions + } +} diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs index 454f0e7b7cb9..0b1c796553c0 100644 --- a/datafusion/substrait/src/lib.rs +++ b/datafusion/substrait/src/lib.rs @@ -72,6 +72,7 @@ //! # Ok(()) //! # } //! ``` +pub mod extensions; pub mod logical_plan; pub mod physical_plan; pub mod serializer; diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 1365630d5079..15c447114819 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License.
-use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; +use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, OffsetBuffer}; use async_recursion::async_recursion; -use datafusion::arrow::array::GenericListArray; +use datafusion::arrow::array::{GenericListArray, MapArray}; use datafusion::arrow::datatypes::{ DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, }; @@ -36,16 +36,22 @@ use datafusion::logical_expr::{ use substrait::proto::expression::subquery::set_predicate::PredicateOp; use url::Url; +use crate::extensions::Extensions; use crate::variation_const::{ DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF, - INTERVAL_YEAR_MONTH_TYPE_REF, LARGE_CONTAINER_TYPE_VARIATION_REF, + INTERVAL_MONTH_DAY_NANO_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, }; +#[allow(deprecated)] +use crate::variation_const::{ + INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF, + INTERVAL_YEAR_MONTH_TYPE_REF, +}; +use datafusion::arrow::array::{new_empty_array, AsArray}; use datafusion::common::scalar::ScalarStructBuilder; use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ @@ -65,7 +71,9 @@ use std::str::FromStr; use std::sync::Arc; use substrait::proto::exchange_rel::ExchangeKind; use substrait::proto::expression::literal::user_defined::Val; -use substrait::proto::expression::literal::{IntervalDayToSecond, IntervalYearToMonth}; +use substrait::proto::expression::literal::{ + IntervalDayToSecond, IntervalYearToMonth, UserDefined, +}; use substrait::proto::expression::subquery::SubqueryType; use substrait::proto::expression::{self, FieldReference, Literal, ScalarFunction}; use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; @@ -78,7 +86,6 @@ use substrait::proto::{ window_function::bound::Kind as BoundKind, window_function::Bound, window_function::BoundsType, MaskExpression, RexType, }, - extensions::simple_extension_declaration::MappingType, function_argument::ArgType, join_rel, plan_rel, r#type, read_rel::ReadType, @@ -185,19 +192,10 @@ pub async fn from_substrait_plan( plan: &Plan, ) -> Result<LogicalPlan> { // Register function extension - let function_extension = plan - .extensions - .iter() - .map(|e| match &e.mapping_type { - Some(ext) => match ext { - MappingType::ExtensionFunction(ext_f) => { - Ok((ext_f.function_anchor, &ext_f.name)) - } - _ => not_impl_err!("Extension type not supported: {ext:?}"), - }, - None => not_impl_err!("Cannot parse empty extension"), - }) - .collect::<Result<HashMap<_, _>>>()?; + let extensions = Extensions::try_from(&plan.extensions)?; + if !extensions.type_variations.is_empty() { + return not_impl_err!("Type variation extensions are not supported"); + } // Parse relations match plan.relations.len() { @@ -205,10 +203,10 @@ match plan.relations[0].rel_type.as_ref() { Some(rt) => match rt { plan_rel::RelType::Rel(rel) => { - Ok(from_substrait_rel(ctx, rel, &function_extension).await?) + Ok(from_substrait_rel(ctx, rel, &extensions).await?)
}, plan_rel::RelType::Root(root) => { - let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?; + let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &extensions).await?; if root.names.is_empty() { // Backwards compatibility for plans missing names return Ok(plan); @@ -396,7 +394,7 @@ fn make_renamed_schema( pub async fn from_substrait_rel( ctx: &SessionContext, rel: &Rel, - extensions: &HashMap, + extensions: &Extensions, ) -> Result { match &rel.rel_type { Some(RelType::Project(p)) => { @@ -660,7 +658,7 @@ pub async fn from_substrait_rel( substrait_datafusion_err!("No base schema provided for Virtual Table") })?; - let schema = from_substrait_named_struct(base_schema)?; + let schema = from_substrait_named_struct(base_schema, extensions)?; if vt.values.is_empty() { return Ok(LogicalPlan::EmptyRelation(EmptyRelation { @@ -681,6 +679,7 @@ pub async fn from_substrait_rel( name_idx += 1; // top-level names are provided through schema Ok(Expr::Literal(from_substrait_literal( lit, + extensions, &base_schema.names, &mut name_idx, )?)) @@ -892,7 +891,7 @@ pub async fn from_substrait_sorts( ctx: &SessionContext, substrait_sorts: &Vec, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result> { let mut sorts: Vec = vec![]; for s in substrait_sorts { @@ -942,7 +941,7 @@ pub async fn from_substrait_rex_vec( ctx: &SessionContext, exprs: &Vec, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result> { let mut expressions: Vec = vec![]; for expr in exprs { @@ -957,7 +956,7 @@ pub async fn from_substrait_func_args( ctx: &SessionContext, arguments: &Vec, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result> { let mut args: Vec = vec![]; for arg in arguments { @@ -977,7 +976,7 @@ pub async fn from_substrait_agg_func( ctx: &SessionContext, f: &AggregateFunction, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, filter: Option>, order_by: Option>, distinct: bool, @@ -985,14 +984,14 @@ pub async fn from_substrait_agg_func( let args = from_substrait_func_args(ctx, &f.arguments, input_schema, extensions).await?; - let Some(function_name) = extensions.get(&f.function_reference) else { + let Some(function_name) = extensions.functions.get(&f.function_reference) else { return plan_err!( "Aggregate function not registered: function anchor = {:?}", f.function_reference ); }; - let function_name = substrait_fun_name((**function_name).as_str()); + let function_name = substrait_fun_name(function_name); // try udaf first, then built-in aggr fn. 
if let Ok(fun) = ctx.udaf(function_name) { // deal with situation that count(*) got no arguments @@ -1025,7 +1024,7 @@ pub async fn from_substrait_rex( ctx: &SessionContext, e: &Expression, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result { match &e.rex_type { Some(RexType::SingularOrList(s)) => { @@ -1105,7 +1104,7 @@ pub async fn from_substrait_rex( })) } Some(RexType::ScalarFunction(f)) => { - let Some(fn_name) = extensions.get(&f.function_reference) else { + let Some(fn_name) = extensions.functions.get(&f.function_reference) else { return plan_err!( "Scalar function not found: function reference = {:?}", f.function_reference @@ -1155,7 +1154,7 @@ pub async fn from_substrait_rex( } } Some(RexType::Literal(lit)) => { - let scalar_value = from_substrait_literal_without_names(lit)?; + let scalar_value = from_substrait_literal_without_names(lit, extensions)?; Ok(Expr::Literal(scalar_value)) } Some(RexType::Cast(cast)) => match cast.as_ref().r#type.as_ref() { @@ -1169,12 +1168,13 @@ pub async fn from_substrait_rex( ) .await?, ), - from_substrait_type_without_names(output_type)?, + from_substrait_type_without_names(output_type, extensions)?, ))), None => substrait_err!("Cast expression without output type is not allowed"), }, Some(RexType::WindowFunction(window)) => { - let Some(fn_name) = extensions.get(&window.function_reference) else { + let Some(fn_name) = extensions.functions.get(&window.function_reference) + else { return plan_err!( "Window function not found: function reference = {:?}", window.function_reference @@ -1328,12 +1328,16 @@ pub async fn from_substrait_rex( } } -pub(crate) fn from_substrait_type_without_names(dt: &Type) -> Result { - from_substrait_type(dt, &[], &mut 0) +pub(crate) fn from_substrait_type_without_names( + dt: &Type, + extensions: &Extensions, +) -> Result { + from_substrait_type(dt, extensions, &[], &mut 0) } fn from_substrait_type( dt: &Type, + extensions: &Extensions, dfs_names: &[String], name_idx: &mut usize, ) -> Result { @@ -1416,7 +1420,7 @@ fn from_substrait_type( substrait_datafusion_err!("List type must have inner type") })?; let field = Arc::new(Field::new_list_field( - from_substrait_type(inner_type, dfs_names, name_idx)?, + from_substrait_type(inner_type, extensions, dfs_names, name_idx)?, // We ignore Substrait's nullability here to match to_substrait_literal // which always creates nullable lists true, @@ -1438,29 +1442,22 @@ fn from_substrait_type( })?; let key_field = Arc::new(Field::new( "key", - from_substrait_type(key_type, dfs_names, name_idx)?, + from_substrait_type(key_type, extensions, dfs_names, name_idx)?, false, )); let value_field = Arc::new(Field::new( "value", - from_substrait_type(value_type, dfs_names, name_idx)?, + from_substrait_type(value_type, extensions, dfs_names, name_idx)?, true, )); - match map.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => { - Ok(DataType::Map( - Arc::new(Field::new_struct( - "entries", - [key_field, value_field], - false, // The inner map field is always non-nullable (Arrow #1697), - )), - false, - )) - } - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - )?, - } + Ok(DataType::Map( + Arc::new(Field::new_struct( + "entries", + [key_field, value_field], + false, // The inner map field is always non-nullable (Arrow #1697), + )), + false, // whether keys are sorted + )) } r#type::Kind::Decimal(d) => match d.type_variation_reference { DECIMAL_128_TYPE_VARIATION_REF => { @@ -1490,28 +1487,41 @@ fn 
from_substrait_type( ), }, r#type::Kind::UserDefined(u) => { - match u.type_reference { - // Kept for backwards compatibility, use IntervalYear instead - INTERVAL_YEAR_MONTH_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::YearMonth)) - } - // Kept for backwards compatibility, use IntervalDay instead - INTERVAL_DAY_TIME_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::DayTime)) + if let Some(name) = extensions.types.get(&u.type_reference) { + match name.as_ref() { + INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), + _ => not_impl_err!( + "Unsupported Substrait user defined type with ref {} and variation {}", + u.type_reference, + u.type_variation_reference + ), } - // Not supported yet by Substrait - INTERVAL_MONTH_DAY_NANO_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::MonthDayNano)) - } - _ => not_impl_err!( + } else { + // Kept for backwards compatibility, new plans should include the extension instead + #[allow(deprecated)] + match u.type_reference { + // Kept for backwards compatibility, use IntervalYear instead + INTERVAL_YEAR_MONTH_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::YearMonth)) + } + // Kept for backwards compatibility, use IntervalDay instead + INTERVAL_DAY_TIME_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::DayTime)) + } + // Not supported yet by Substrait + INTERVAL_MONTH_DAY_NANO_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } + _ => not_impl_err!( "Unsupported Substrait user defined type with ref {} and variation {}", u.type_reference, u.type_variation_reference ), + } } } r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type( - s, dfs_names, name_idx, + s, extensions, dfs_names, name_idx, )?)), r#type::Kind::Varchar(_) => Ok(DataType::Utf8), r#type::Kind::FixedChar(_) => Ok(DataType::Utf8), @@ -1523,6 +1533,7 @@ fn from_substrait_type( fn from_substrait_struct_type( s: &r#type::Struct, + extensions: &Extensions, dfs_names: &[String], name_idx: &mut usize, ) -> Result { @@ -1530,7 +1541,7 @@ fn from_substrait_struct_type( for (i, f) in s.types.iter().enumerate() { let field = Field::new( next_struct_field_name(i, dfs_names, name_idx)?, - from_substrait_type(f, dfs_names, name_idx)?, + from_substrait_type(f, extensions, dfs_names, name_idx)?, true, // We assume everything to be nullable since that's easier than ensuring it matches ); fields.push(field); @@ -1556,12 +1567,16 @@ fn next_struct_field_name( } } -fn from_substrait_named_struct(base_schema: &NamedStruct) -> Result { +fn from_substrait_named_struct( + base_schema: &NamedStruct, + extensions: &Extensions, +) -> Result { let mut name_idx = 0; let fields = from_substrait_struct_type( base_schema.r#struct.as_ref().ok_or_else(|| { substrait_datafusion_err!("Named struct must contain a struct") })?, + extensions, &base_schema.names, &mut name_idx, ); @@ -1621,12 +1636,16 @@ fn from_substrait_bound( } } -pub(crate) fn from_substrait_literal_without_names(lit: &Literal) -> Result { - from_substrait_literal(lit, &vec![], &mut 0) +pub(crate) fn from_substrait_literal_without_names( + lit: &Literal, + extensions: &Extensions, +) -> Result { + from_substrait_literal(lit, extensions, &vec![], &mut 0) } fn from_substrait_literal( lit: &Literal, + extensions: &Extensions, dfs_names: &Vec, name_idx: &mut usize, ) -> Result { @@ -1718,11 +1737,23 @@ fn from_substrait_literal( ) } Some(LiteralType::List(l)) => { + // Each element should start the name index from the same value, then we increase it + // once at the end + 
let mut element_name_idx = *name_idx; let elements = l .values .iter() - .map(|el| from_substrait_literal(el, dfs_names, name_idx)) + .map(|el| { + element_name_idx = *name_idx; + from_substrait_literal( + el, + extensions, + dfs_names, + &mut element_name_idx, + ) + }) .collect::>>()?; + *name_idx = element_name_idx; if elements.is_empty() { return substrait_err!( "Empty list must be encoded as EmptyList literal type, not List" @@ -1744,6 +1775,7 @@ fn from_substrait_literal( Some(LiteralType::EmptyList(l)) => { let element_type = from_substrait_type( l.r#type.clone().unwrap().as_ref(), + extensions, dfs_names, name_idx, )?; @@ -1759,11 +1791,89 @@ fn from_substrait_literal( } } } + Some(LiteralType::Map(m)) => { + // Each entry should start the name index from the same value, then we increase it + // once at the end + let mut entry_name_idx = *name_idx; + let entries = m + .key_values + .iter() + .map(|kv| { + entry_name_idx = *name_idx; + let key_sv = from_substrait_literal( + kv.key.as_ref().unwrap(), + extensions, + dfs_names, + &mut entry_name_idx, + )?; + let value_sv = from_substrait_literal( + kv.value.as_ref().unwrap(), + extensions, + dfs_names, + &mut entry_name_idx, + )?; + ScalarStructBuilder::new() + .with_scalar(Field::new("key", key_sv.data_type(), false), key_sv) + .with_scalar( + Field::new("value", value_sv.data_type(), true), + value_sv, + ) + .build() + }) + .collect::>>()?; + *name_idx = entry_name_idx; + + if entries.is_empty() { + return substrait_err!( + "Empty map must be encoded as EmptyMap literal type, not Map" + ); + } + + ScalarValue::Map(Arc::new(MapArray::new( + Arc::new(Field::new("entries", entries[0].data_type(), false)), + OffsetBuffer::new(vec![0, entries.len() as i32].into()), + ScalarValue::iter_to_array(entries)?.as_struct().to_owned(), + None, + false, + ))) + } + Some(LiteralType::EmptyMap(m)) => { + let key = match &m.key { + Some(k) => Ok(k), + _ => plan_err!("Missing key type for empty map"), + }?; + let value = match &m.value { + Some(v) => Ok(v), + _ => plan_err!("Missing value type for empty map"), + }?; + let key_type = from_substrait_type(key, extensions, dfs_names, name_idx)?; + let value_type = from_substrait_type(value, extensions, dfs_names, name_idx)?; + + // new_empty_array on a MapType creates a too empty array + // We want it to contain an empty struct array to align with an empty MapBuilder one + let entries = Field::new_struct( + "entries", + vec![ + Field::new("key", key_type, false), + Field::new("value", value_type, true), + ], + false, + ); + let struct_array = + new_empty_array(entries.data_type()).as_struct().to_owned(); + ScalarValue::Map(Arc::new(MapArray::new( + Arc::new(entries), + OffsetBuffer::new(vec![0, 0].into()), + struct_array, + None, + false, + ))) + } Some(LiteralType::Struct(s)) => { let mut builder = ScalarStructBuilder::new(); for (i, field) in s.fields.iter().enumerate() { let name = next_struct_field_name(i, dfs_names, name_idx)?; - let sv = from_substrait_literal(field, dfs_names, name_idx)?; + let sv = from_substrait_literal(field, extensions, dfs_names, name_idx)?; // We assume everything to be nullable, since Arrow's strict about things matching // and it's hard to match otherwise. builder = builder.with_scalar(Field::new(name, sv.data_type(), true), sv); @@ -1771,7 +1881,7 @@ fn from_substrait_literal( builder.build()? } Some(LiteralType::Null(ntype)) => { - from_substrait_null(ntype, dfs_names, name_idx)? + from_substrait_null(ntype, extensions, dfs_names, name_idx)? 
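The List and Map literal handling above has a subtle invariant: the flattened `dfs_names` describe the element (or key/value) type once, not once per element, so every element is decoded starting from the same `name_idx` and the shared index is advanced only after the loop. A small illustrative model of that bookkeeping (the helper functions are made up, not the crate's API):

```rust
// Illustrative-only model of the `name_idx` bookkeeping: all elements of a
// list (or entries of a map) share one element type, so each element is
// decoded from a copy of the current index and the shared index moves once.
fn decode_element(names: &[&str], idx: &mut usize) -> Vec<String> {
    // Pretend every element's type consumes exactly two flattened names.
    let taken = names[*idx..*idx + 2].iter().map(|s| s.to_string()).collect();
    *idx += 2;
    taken
}

fn decode_list(names: &[&str], idx: &mut usize, len: usize) -> Vec<Vec<String>> {
    let mut element_idx = *idx;
    let elements = (0..len)
        .map(|_| {
            element_idx = *idx; // every element starts from the same position
            decode_element(names, &mut element_idx)
        })
        .collect();
    *idx = element_idx; // advance once, past a single element's names
    elements
}

fn main() {
    let names = ["key", "value", "rest"];
    let mut idx = 0;
    let decoded = decode_list(&names, &mut idx, 3);
    assert!(decoded.iter().all(|e| e == &["key", "value"]));
    assert_eq!(idx, 2); // not 6: the names were consumed once, not per element
}
```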
} Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond { days, @@ -1786,40 +1896,9 @@ fn from_substrait_literal( } Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())), Some(LiteralType::UserDefined(user_defined)) => { - match user_defined.type_reference { - // Kept for backwards compatibility, use IntervalYearToMonth instead - INTERVAL_YEAR_MONTH_TYPE_REF => { - let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { - return substrait_err!("Interval year month value is empty"); - }; - let value_slice: [u8; 4] = - (*raw_val.value).try_into().map_err(|_| { - substrait_datafusion_err!( - "Failed to parse interval year month value" - ) - })?; - ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes(value_slice))) - } - // Kept for backwards compatibility, use IntervalDayToSecond instead - INTERVAL_DAY_TIME_TYPE_REF => { - let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { - return substrait_err!("Interval day time value is empty"); - }; - let value_slice: [u8; 8] = - (*raw_val.value).try_into().map_err(|_| { - substrait_datafusion_err!( - "Failed to parse interval day time value" - ) - })?; - let days = i32::from_le_bytes(value_slice[0..4].try_into().unwrap()); - let milliseconds = - i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); - ScalarValue::IntervalDayTime(Some(IntervalDayTime { - days, - milliseconds, - })) - } - INTERVAL_MONTH_DAY_NANO_TYPE_REF => { + // Helper function to prevent duplicating this code - can be inlined once the non-extension path is removed + let interval_month_day_nano = + |user_defined: &UserDefined| -> Result { let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { return substrait_err!("Interval month day nano value is empty"); }; @@ -1834,17 +1913,76 @@ fn from_substrait_literal( let days = i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); let nanoseconds = i64::from_le_bytes(value_slice[8..16].try_into().unwrap()); - ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { - months, - days, - nanoseconds, - })) - } - _ => { - return not_impl_err!( - "Unsupported Substrait user defined type with ref {}", - user_defined.type_reference + Ok(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano { + months, + days, + nanoseconds, + }, + ))) + }; + + if let Some(name) = extensions.types.get(&user_defined.type_reference) { + match name.as_ref() { + INTERVAL_MONTH_DAY_NANO_TYPE_NAME => { + interval_month_day_nano(user_defined)? 
+ } + _ => { + return not_impl_err!( + "Unsupported Substrait user defined type with ref {} and name {}", + user_defined.type_reference, + name ) + } + } + } else { + // Kept for backwards compatibility - new plans should include extension instead + #[allow(deprecated)] + match user_defined.type_reference { + // Kept for backwards compatibility, use IntervalYearToMonth instead + INTERVAL_YEAR_MONTH_TYPE_REF => { + let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { + return substrait_err!("Interval year month value is empty"); + }; + let value_slice: [u8; 4] = + (*raw_val.value).try_into().map_err(|_| { + substrait_datafusion_err!( + "Failed to parse interval year month value" + ) + })?; + ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes( + value_slice, + ))) + } + // Kept for backwards compatibility, use IntervalDayToSecond instead + INTERVAL_DAY_TIME_TYPE_REF => { + let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { + return substrait_err!("Interval day time value is empty"); + }; + let value_slice: [u8; 8] = + (*raw_val.value).try_into().map_err(|_| { + substrait_datafusion_err!( + "Failed to parse interval day time value" + ) + })?; + let days = + i32::from_le_bytes(value_slice[0..4].try_into().unwrap()); + let milliseconds = + i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days, + milliseconds, + })) + } + INTERVAL_MONTH_DAY_NANO_TYPE_REF => { + interval_month_day_nano(user_defined)? + } + _ => { + return not_impl_err!( + "Unsupported Substrait user defined type literal with ref {}", + user_defined.type_reference + ) + } } } } @@ -1856,6 +1994,7 @@ fn from_substrait_literal( fn from_substrait_null( null_type: &Type, + extensions: &Extensions, dfs_names: &[String], name_idx: &mut usize, ) -> Result { @@ -1940,6 +2079,7 @@ fn from_substrait_null( let field = Field::new_list_field( from_substrait_type( l.r#type.clone().unwrap().as_ref(), + extensions, dfs_names, name_idx, )?, @@ -1957,8 +2097,32 @@ fn from_substrait_null( ), } } + r#type::Kind::Map(map) => { + let key_type = map.key.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Map type must have key type") + })?; + let value_type = map.value.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Map type must have value type") + })?; + + let key_type = + from_substrait_type(key_type, extensions, dfs_names, name_idx)?; + let value_type = + from_substrait_type(value_type, extensions, dfs_names, name_idx)?; + let entries_field = Arc::new(Field::new_struct( + "entries", + vec![ + Field::new("key", key_type, false), + Field::new("value", value_type, true), + ], + false, + )); + + DataType::Map(entries_field, false /* keys sorted */).try_into() + } r#type::Kind::Struct(s) => { - let fields = from_substrait_struct_type(s, dfs_names, name_idx)?; + let fields = + from_substrait_struct_type(s, extensions, dfs_names, name_idx)?; Ok(ScalarStructBuilder::new_null(fields)) } _ => not_impl_err!("Unsupported Substrait type for null: {kind:?}"), @@ -2012,7 +2176,7 @@ impl BuiltinExprBuilder { ctx: &SessionContext, f: &ScalarFunction, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result { match self.expr_name.as_str() { "like" => { @@ -2037,7 +2201,7 @@ impl BuiltinExprBuilder { fn_name: &str, f: &ScalarFunction, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result { if f.arguments.len() != 1 { return substrait_err!("Expect one argument for {fn_name} expr"); @@ -2071,7 +2235,7 @@ 
impl BuiltinExprBuilder { case_insensitive: bool, f: &ScalarFunction, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result { let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; if f.arguments.len() != 2 && f.arguments.len() != 3 { diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 7849d0bd431e..8263209ffccc 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -16,7 +16,6 @@ // under the License. use itertools::Itertools; -use std::collections::HashMap; use std::ops::Deref; use std::sync::Arc; @@ -33,6 +32,16 @@ use datafusion::{ scalar::ScalarValue, }; +use crate::extensions::Extensions; +use crate::variation_const::{ + DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, + DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + INTERVAL_MONTH_DAY_NANO_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF, + TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, + TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, + UNSIGNED_INTEGER_TYPE_VARIATION_REF, +}; use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; use datafusion::common::{ exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err, @@ -48,8 +57,10 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Opera use datafusion::prelude::Expr; use pbjson_types::Any as ProtoAny; use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; +use substrait::proto::expression::literal::map::KeyValue; use substrait::proto::expression::literal::{ - user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Struct, UserDefined, + user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Map, Struct, + UserDefined, }; use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; @@ -72,10 +83,6 @@ use substrait::{ ScalarFunction, SingularOrList, Subquery, WindowFunction as SubstraitWindowFunction, }, - extensions::{ - self, - simple_extension_declaration::{ExtensionFunction, MappingType}, - }, function_argument::ArgType, join_rel, plan_rel, r#type, read_rel::{NamedTable, ReadType}, @@ -90,39 +97,24 @@ use substrait::{ version, }; -use crate::variation_const::{ - DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, - DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, - DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_URL, - LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, - TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF, - TIMESTAMP_SECOND_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, -}; - /// Convert DataFusion LogicalPlan to Substrait Plan pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result> { + let mut extensions = Extensions::default(); // Parse relation nodes - let mut extension_info: ( - Vec, - HashMap, - ) = (vec![], HashMap::new()); // Generate PlanRel(s) // Note: Only 1 relation tree is currently supported let plan_rels = vec![PlanRel { rel_type: Some(plan_rel::RelType::Root(RelRoot { - input: Some(*to_substrait_rel(plan, ctx, &mut extension_info)?), - names: to_substrait_named_struct(plan.schema())?.names, + input: 
Some(*to_substrait_rel(plan, ctx, &mut extensions)?), + names: to_substrait_named_struct(plan.schema(), &mut extensions)?.names, })), }]; - let (function_extensions, _) = extension_info; - // Return parsed plan Ok(Box::new(Plan { version: Some(version::version_with_producer("datafusion")), extension_uris: vec![], - extensions: function_extensions, + extensions: extensions.into(), relations: plan_rels, advanced_extensions: None, expected_type_urls: vec![], @@ -133,10 +125,7 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result, - HashMap, - ), + extensions: &mut Extensions, ) -> Result> { match plan { LogicalPlan::TableScan(scan) => { @@ -187,7 +176,7 @@ pub fn to_substrait_rel( Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { common: None, - base_schema: Some(to_substrait_named_struct(&e.schema)?), + base_schema: Some(to_substrait_named_struct(&e.schema, extensions)?), filter: None, best_effort_filter: None, projection: None, @@ -206,10 +195,10 @@ pub fn to_substrait_rel( let fields = row .iter() .map(|v| match v { - Expr::Literal(sv) => to_substrait_literal(sv), + Expr::Literal(sv) => to_substrait_literal(sv, extensions), Expr::Alias(alias) => match alias.expr.as_ref() { // The schema gives us the names, so we can skip aliases - Expr::Literal(sv) => to_substrait_literal(sv), + Expr::Literal(sv) => to_substrait_literal(sv, extensions), _ => Err(substrait_datafusion_err!( "Only literal types can be aliased in Virtual Tables, got: {}", alias.expr.variant_name() )), @@ -225,7 +214,7 @@ pub fn to_substrait_rel( Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { common: None, - base_schema: Some(to_substrait_named_struct(&v.schema)?), + base_schema: Some(to_substrait_named_struct(&v.schema, extensions)?), filter: None, best_effort_filter: None, projection: None, @@ -238,25 +227,25 @@ pub fn to_substrait_rel( let expressions = p .expr .iter() - .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, extension_info)) + .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, extensions)) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Project(Box::new(ProjectRel { common: None, - input: Some(to_substrait_rel(p.input.as_ref(), ctx, extension_info)?), + input: Some(to_substrait_rel(p.input.as_ref(), ctx, extensions)?), expressions, advanced_extension: None, }))), })) } LogicalPlan::Filter(filter) => { - let input = to_substrait_rel(filter.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(filter.input.as_ref(), ctx, extensions)?; let filter_expr = to_substrait_rex( ctx, &filter.predicate, filter.input.schema(), 0, - extension_info, + extensions, )?; Ok(Box::new(Rel { rel_type: Some(RelType::Filter(Box::new(FilterRel { @@ -268,7 +257,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Limit(limit) => { - let input = to_substrait_rel(limit.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(limit.input.as_ref(), ctx, extensions)?; // Since protobuf can't directly distinguish `None` vs `0` encode `None` as `MAX` let limit_fetch = limit.fetch.unwrap_or(usize::MAX); Ok(Box::new(Rel { @@ -282,13 +271,11 @@ pub fn to_substrait_rel( })) } LogicalPlan::Sort(sort) => { - let input = to_substrait_rel(sort.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(sort.input.as_ref(), ctx, extensions)?; let sort_fields = sort .expr .iter() - .map(|e| { - substrait_sort_field(ctx, e, sort.input.schema(), extension_info) - }) + .map(|e| substrait_sort_field(ctx, e, sort.input.schema(), extensions)) 
.collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Sort(Box::new(SortRel { @@ -300,19 +287,17 @@ pub fn to_substrait_rel( })) } LogicalPlan::Aggregate(agg) => { - let input = to_substrait_rel(agg.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(agg.input.as_ref(), ctx, extensions)?; let groupings = to_substrait_groupings( ctx, &agg.group_expr, agg.input.schema(), - extension_info, + extensions, )?; let measures = agg .aggr_expr .iter() - .map(|e| { - to_substrait_agg_measure(ctx, e, agg.input.schema(), extension_info) - }) + .map(|e| to_substrait_agg_measure(ctx, e, agg.input.schema(), extensions)) .collect::>>()?; Ok(Box::new(Rel { @@ -327,7 +312,7 @@ pub fn to_substrait_rel( } LogicalPlan::Distinct(Distinct::All(plan)) => { // Use Substrait's AggregateRel with empty measures to represent `select distinct` - let input = to_substrait_rel(plan.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(plan.as_ref(), ctx, extensions)?; // Get grouping keys from the input relation's number of output fields let grouping = (0..plan.schema().fields().len()) .map(substrait_field_ref) @@ -346,8 +331,8 @@ pub fn to_substrait_rel( })) } LogicalPlan::Join(join) => { - let left = to_substrait_rel(join.left.as_ref(), ctx, extension_info)?; - let right = to_substrait_rel(join.right.as_ref(), ctx, extension_info)?; + let left = to_substrait_rel(join.left.as_ref(), ctx, extensions)?; + let right = to_substrait_rel(join.right.as_ref(), ctx, extensions)?; let join_type = to_substrait_jointype(join.join_type); // we only support basic joins so return an error for anything not yet supported match join.join_constraint { @@ -364,7 +349,7 @@ pub fn to_substrait_rel( filter, &Arc::new(in_join_schema), 0, - extension_info, + extensions, )?), None => None, }; @@ -382,7 +367,7 @@ pub fn to_substrait_rel( eq_op, join.left.schema(), join.right.schema(), - extension_info, + extensions, )?; // create conjunction between `join_on` and `join_filter` to embed all join conditions, @@ -393,7 +378,7 @@ pub fn to_substrait_rel( on_expr, filter, Operator::And, - extension_info, + extensions, ))), None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist }, @@ -421,8 +406,8 @@ pub fn to_substrait_rel( right, schema: _, } = cross_join; - let left = to_substrait_rel(left.as_ref(), ctx, extension_info)?; - let right = to_substrait_rel(right.as_ref(), ctx, extension_info)?; + let left = to_substrait_rel(left.as_ref(), ctx, extensions)?; + let right = to_substrait_rel(right.as_ref(), ctx, extensions)?; Ok(Box::new(Rel { rel_type: Some(RelType::Cross(Box::new(CrossRel { common: None, @@ -435,13 +420,13 @@ pub fn to_substrait_rel( LogicalPlan::SubqueryAlias(alias) => { // Do nothing if encounters SubqueryAlias // since there is no corresponding relation type in Substrait - to_substrait_rel(alias.input.as_ref(), ctx, extension_info) + to_substrait_rel(alias.input.as_ref(), ctx, extensions) } LogicalPlan::Union(union) => { let input_rels = union .inputs .iter() - .map(|input| to_substrait_rel(input.as_ref(), ctx, extension_info)) + .map(|input| to_substrait_rel(input.as_ref(), ctx, extensions)) .collect::>>()? 
.into_iter() .map(|ptr| *ptr) @@ -456,7 +441,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Window(window) => { - let input = to_substrait_rel(window.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(window.input.as_ref(), ctx, extensions)?; // If the input is a Project relation, we can just append the WindowFunction expressions // before returning // Otherwise, wrap the input in a Project relation before appending the WindowFunction @@ -484,7 +469,7 @@ pub fn to_substrait_rel( expr, window.input.schema(), 0, - extension_info, + extensions, )?); } // Append parsed WindowFunction expressions @@ -494,8 +479,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Repartition(repartition) => { - let input = - to_substrait_rel(repartition.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(repartition.input.as_ref(), ctx, extensions)?; let partition_count = match repartition.partitioning_scheme { Partitioning::RoundRobinBatch(num) => num, Partitioning::Hash(_, num) => num, @@ -553,7 +537,7 @@ pub fn to_substrait_rel( .node .inputs() .into_iter() - .map(|plan| to_substrait_rel(plan, ctx, extension_info)) + .map(|plan| to_substrait_rel(plan, ctx, extensions)) .collect::>>()?; let rel_type = match inputs_rel.len() { 0 => RelType::ExtensionLeaf(ExtensionLeafRel { @@ -579,7 +563,10 @@ pub fn to_substrait_rel( } } -fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result { +fn to_substrait_named_struct( + schema: &DFSchemaRef, + extensions: &mut Extensions, +) -> Result { // Substrait wants a list of all field names, including nested fields from structs, // also from within e.g. lists and maps. However, it does not want the list and map field names // themselves - only proper structs fields are considered to have useful names. 
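The comment above spells out which names feed the Substrait `NamedStruct`: the schema's columns and any nested struct fields contribute names in depth-first order, while list and map wrappers themselves do not. A toy model of that flattening rule, using made-up types instead of Arrow's (illustrative only):

```rust
// Toy model: struct fields (at any depth) contribute their names in
// depth-first order; list wrappers contribute none, but a struct nested
// inside a list still does. Maps behave like lists in this respect.
enum ToyType {
    Int64,
    List(Box<ToyType>),
    Struct(Vec<(String, ToyType)>),
}

fn flatten_names(ty: &ToyType, out: &mut Vec<String>) {
    match ty {
        ToyType::Int64 => {}
        ToyType::List(inner) => flatten_names(inner, out),
        ToyType::Struct(fields) => {
            for (name, child) in fields {
                out.push(name.clone());
                flatten_names(child, out);
            }
        }
    }
}

fn main() {
    // A single column `a: Struct<b: Int64, c: List<Struct<d: Int64>>>`.
    let schema = ToyType::Struct(vec![(
        "a".into(),
        ToyType::Struct(vec![
            ("b".into(), ToyType::Int64),
            (
                "c".into(),
                ToyType::List(Box::new(ToyType::Struct(vec![(
                    "d".into(),
                    ToyType::Int64,
                )]))),
            ),
        ]),
    )]);

    let mut names = Vec::new();
    flatten_names(&schema, &mut names);
    assert_eq!(names, ["a", "b", "c", "d"]);
}
```

The consumer relies on the same ordering when it walks `dfs_names` with `name_idx` in `from_substrait_type`, `from_substrait_struct_type`, and the literal decoding shown earlier.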
@@ -624,7 +611,7 @@ fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result { types: schema .fields() .iter() - .map(|f| to_substrait_type(f.data_type(), f.is_nullable())) + .map(|f| to_substrait_type(f.data_type(), f.is_nullable(), extensions)) .collect::>()?, type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability: r#type::Nullability::Unspecified as i32, @@ -642,30 +629,27 @@ fn to_substrait_join_expr( eq_op: Operator, left_schema: &DFSchemaRef, right_schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result> { // Only support AND conjunction for each binary expression in join conditions let mut exprs: Vec = vec![]; for (left, right) in join_conditions { // Parse left - let l = to_substrait_rex(ctx, left, left_schema, 0, extension_info)?; + let l = to_substrait_rex(ctx, left, left_schema, 0, extensions)?; // Parse right let r = to_substrait_rex( ctx, right, right_schema, left_schema.fields().len(), // offset to return the correct index - extension_info, + extensions, )?; // AND with existing expression - exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extension_info)); + exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extensions)); } let join_expr: Option = exprs.into_iter().reduce(|acc: Expression, e: Expression| { - make_binary_op_scalar_func(&acc, &e, Operator::And, extension_info) + make_binary_op_scalar_func(&acc, &e, Operator::And, extensions) }); Ok(join_expr) } @@ -722,14 +706,11 @@ pub fn parse_flat_grouping_exprs( ctx: &SessionContext, exprs: &[Expr], schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { let grouping_expressions = exprs .iter() - .map(|e| to_substrait_rex(ctx, e, schema, 0, extension_info)) + .map(|e| to_substrait_rex(ctx, e, schema, 0, extensions)) .collect::>>()?; Ok(Grouping { grouping_expressions, @@ -740,10 +721,7 @@ pub fn to_substrait_groupings( ctx: &SessionContext, exprs: &[Expr], schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result> { match exprs.len() { 1 => match &exprs[0] { @@ -753,9 +731,7 @@ pub fn to_substrait_groupings( )), GroupingSet::GroupingSets(sets) => Ok(sets .iter() - .map(|set| { - parse_flat_grouping_exprs(ctx, set, schema, extension_info) - }) + .map(|set| parse_flat_grouping_exprs(ctx, set, schema, extensions)) .collect::>>()?), GroupingSet::Rollup(set) => { let mut sets: Vec> = vec![vec![]]; @@ -766,23 +742,17 @@ pub fn to_substrait_groupings( .iter() .rev() .map(|set| { - parse_flat_grouping_exprs(ctx, set, schema, extension_info) + parse_flat_grouping_exprs(ctx, set, schema, extensions) }) .collect::>>()?) } }, _ => Ok(vec![parse_flat_grouping_exprs( - ctx, - exprs, - schema, - extension_info, + ctx, exprs, schema, extensions, )?]), }, _ => Ok(vec![parse_flat_grouping_exprs( - ctx, - exprs, - schema, - extension_info, + ctx, exprs, schema, extensions, )?]), } } @@ -792,25 +762,22 @@ pub fn to_substrait_agg_measure( ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { match expr { Expr::AggregateFunction(expr::AggregateFunction { func_def, args, distinct, filter, order_by, null_treatment: _, }) => { match func_def { AggregateFunctionDefinition::BuiltIn (fun) => { let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extension_info)).collect::>>()? 
+ order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extensions)).collect::>>()? } else { vec![] }; let mut arguments: Vec = vec![]; for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extensions)?)) }); } - let function_anchor = register_function(fun.to_string(), extension_info); + let function_anchor = extensions.register_function(fun.to_string()); Ok(Measure { measure: Some(AggregateFunction { function_reference: function_anchor, @@ -826,22 +793,22 @@ pub fn to_substrait_agg_measure( options: vec![], }), filter: match filter { - Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extension_info)?), + Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extensions)?), None => None } }) } AggregateFunctionDefinition::UDF(fun) => { let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extension_info)).collect::>>()? + order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extensions)).collect::>>()? } else { vec![] }; let mut arguments: Vec = vec![]; for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extensions)?)) }); } - let function_anchor = register_function(fun.name().to_string(), extension_info); + let function_anchor = extensions.register_function(fun.name().to_string()); Ok(Measure { measure: Some(AggregateFunction { function_reference: function_anchor, @@ -857,7 +824,7 @@ pub fn to_substrait_agg_measure( options: vec![], }), filter: match filter { - Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extension_info)?), + Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extensions)?), None => None } }) @@ -866,7 +833,7 @@ pub fn to_substrait_agg_measure( } Expr::Alias(Alias{expr,..})=> { - to_substrait_agg_measure(ctx, expr, schema, extension_info) + to_substrait_agg_measure(ctx, expr, schema, extensions) } _ => internal_err!( "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}", @@ -881,10 +848,7 @@ fn to_substrait_sort_field( ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { match expr { Expr::Sort(sort) => { @@ -900,7 +864,7 @@ fn to_substrait_sort_field( sort.expr.deref(), schema, 0, - extension_info, + extensions, )?), sort_kind: Some(SortKind::Direction(sort_kind.into())), }) @@ -909,67 +873,15 @@ fn to_substrait_sort_field( } } -fn register_function( - function_name: String, - extension_info: &mut ( - Vec, - HashMap, - ), -) -> u32 { - let (function_extensions, function_set) = extension_info; - let function_name = function_name.to_lowercase(); - - // Some functions are named differently in Substrait default extensions than in DF - // Rename those to match the Substrait extensions for interoperability - let function_name = match function_name.as_str() { - "substr" => "substring".to_string(), - _ => function_name, - }; - - // To prevent ambiguous references between ScalarFunctions and AggregateFunctions, - // a plan-relative identifier starting from 0 is used as the function_anchor. 
- // The consumer is responsible for correctly registering - // mapping info stored in the extensions by the producer. - let function_anchor = match function_set.get(&function_name) { - Some(function_anchor) => { - // Function has been registered - *function_anchor - } - None => { - // Function has NOT been registered - let function_anchor = function_set.len() as u32; - function_set.insert(function_name.clone(), function_anchor); - - let function_extension = ExtensionFunction { - extension_uri_reference: u32::MAX, - function_anchor, - name: function_name, - }; - let simple_extension = extensions::SimpleExtensionDeclaration { - mapping_type: Some(MappingType::ExtensionFunction(function_extension)), - }; - function_extensions.push(simple_extension); - function_anchor - } - }; - - // Return function anchor - function_anchor -} - /// Return Substrait scalar function with two arguments #[allow(deprecated)] pub fn make_binary_op_scalar_func( lhs: &Expression, rhs: &Expression, op: Operator, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Expression { - let function_anchor = - register_function(operator_to_name(op).to_string(), extension_info); + let function_anchor = extensions.register_function(operator_to_name(op).to_string()); Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -994,7 +906,7 @@ pub fn make_binary_op_scalar_func( /// /// * `expr` - DataFusion expression to be parse into a Substrait expression /// * `schema` - DataFusion input schema for looking up field qualifiers -/// * `col_ref_offset` - Offset for caculating Substrait field reference indices. +/// * `col_ref_offset` - Offset for calculating Substrait field reference indices. /// This should only be set by caller with more than one input relations i.e. Join. /// Substrait expects one set of indices when joining two relations. /// Let's say `left` and `right` have `m` and `n` columns, respectively. The `right` @@ -1010,17 +922,14 @@ pub fn make_binary_op_scalar_func( /// `col_ref(1) = col_ref(3 + 0)` /// , where `3` is the number of `left` columns (`col_ref_offset`) and `0` is the index /// of the join key column from `right` -/// * `extension_info` - Substrait extension info. Contains registered function information +/// * `extensions` - Substrait extension info. 
Contains registered function information #[allow(deprecated)] pub fn to_substrait_rex( ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, col_ref_offset: usize, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { match expr { Expr::InList(InList { @@ -1030,10 +939,10 @@ pub fn to_substrait_rex( }) => { let substrait_list = list .iter() - .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset, extension_info)) + .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset, extensions)) .collect::>>()?; let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; let substrait_or_list = Expression { rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { @@ -1043,8 +952,7 @@ pub fn to_substrait_rex( }; if *negated { - let function_anchor = - register_function("not".to_string(), extension_info); + let function_anchor = extensions.register_function("not".to_string()); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -1070,13 +978,12 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, )?)), }); } - let function_anchor = - register_function(fun.name().to_string(), extension_info); + let function_anchor = extensions.register_function(fun.name().to_string()); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -1096,58 +1003,58 @@ pub fn to_substrait_rex( if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; let substrait_low = - to_substrait_rex(ctx, low, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, low, schema, col_ref_offset, extensions)?; let substrait_high = - to_substrait_rex(ctx, high, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, high, schema, col_ref_offset, extensions)?; let l_expr = make_binary_op_scalar_func( &substrait_expr, &substrait_low, Operator::Lt, - extension_info, + extensions, ); let r_expr = make_binary_op_scalar_func( &substrait_high, &substrait_expr, Operator::Lt, - extension_info, + extensions, ); Ok(make_binary_op_scalar_func( &l_expr, &r_expr, Operator::Or, - extension_info, + extensions, )) } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; let substrait_low = - to_substrait_rex(ctx, low, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, low, schema, col_ref_offset, extensions)?; let substrait_high = - to_substrait_rex(ctx, high, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, high, schema, col_ref_offset, extensions)?; let l_expr = make_binary_op_scalar_func( &substrait_low, &substrait_expr, Operator::LtEq, - extension_info, + extensions, ); let r_expr = make_binary_op_scalar_func( &substrait_expr, &substrait_high, Operator::LtEq, - extension_info, + extensions, ); Ok(make_binary_op_scalar_func( &l_expr, &r_expr, Operator::And, - extension_info, + extensions, )) } } @@ -1156,10 +1063,10 @@ pub fn to_substrait_rex( substrait_field_ref(index + col_ref_offset) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l 
= to_substrait_rex(ctx, left, schema, col_ref_offset, extension_info)?; - let r = to_substrait_rex(ctx, right, schema, col_ref_offset, extension_info)?; + let l = to_substrait_rex(ctx, left, schema, col_ref_offset, extensions)?; + let r = to_substrait_rex(ctx, right, schema, col_ref_offset, extensions)?; - Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info)) + Ok(make_binary_op_scalar_func(&l, &r, *op, extensions)) } Expr::Case(Case { expr, @@ -1176,7 +1083,7 @@ pub fn to_substrait_rex( e, schema, col_ref_offset, - extension_info, + extensions, )?), then: None, }); @@ -1189,14 +1096,14 @@ pub fn to_substrait_rex( r#if, schema, col_ref_offset, - extension_info, + extensions, )?), then: Some(to_substrait_rex( ctx, then, schema, col_ref_offset, - extension_info, + extensions, )?), }); } @@ -1208,7 +1115,7 @@ pub fn to_substrait_rex( e, schema, col_ref_offset, - extension_info, + extensions, )?)), None => None, }; @@ -1221,22 +1128,22 @@ pub fn to_substrait_rex( Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), + r#type: Some(to_substrait_type(data_type, true, extensions)?), input: Some(Box::new(to_substrait_rex( ctx, expr, schema, col_ref_offset, - extension_info, + extensions, )?)), failure_behavior: 0, // FAILURE_BEHAVIOR_UNSPECIFIED }, ))), }) } - Expr::Literal(value) => to_substrait_literal_expr(value), + Expr::Literal(value) => to_substrait_literal_expr(value, extensions), Expr::Alias(Alias { expr, .. }) => { - to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info) + to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions) } Expr::WindowFunction(WindowFunction { fun, @@ -1247,7 +1154,7 @@ pub fn to_substrait_rex( null_treatment: _, }) => { // function reference - let function_anchor = register_function(fun.to_string(), extension_info); + let function_anchor = extensions.register_function(fun.to_string()); // arguments let mut arguments: Vec = vec![]; for arg in args { @@ -1257,19 +1164,19 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, )?)), }); } // partition by expressions let partition_by = partition_by .iter() - .map(|e| to_substrait_rex(ctx, e, schema, col_ref_offset, extension_info)) + .map(|e| to_substrait_rex(ctx, e, schema, col_ref_offset, extensions)) .collect::>>()?; // order by expressions let order_by = order_by .iter() - .map(|e| substrait_sort_field(ctx, e, schema, extension_info)) + .map(|e| substrait_sort_field(ctx, e, schema, extensions)) .collect::>>()?; // window frame let bounds = to_substrait_bounds(window_frame)?; @@ -1298,7 +1205,7 @@ pub fn to_substrait_rex( *escape_char, schema, col_ref_offset, - extension_info, + extensions, ), Expr::InSubquery(InSubquery { expr, @@ -1306,10 +1213,10 @@ pub fn to_substrait_rex( negated, }) => { let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; let subquery_plan = - to_substrait_rel(subquery.subquery.as_ref(), ctx, extension_info)?; + to_substrait_rel(subquery.subquery.as_ref(), ctx, extensions)?; let substrait_subquery = Expression { rex_type: Some(RexType::Subquery(Box::new(Subquery { @@ -1324,8 +1231,7 @@ pub fn to_substrait_rex( }))), }; if *negated { - let function_anchor = - register_function("not".to_string(), extension_info); + let function_anchor = extensions.register_function("not".to_string()); Ok(Expression { rex_type: 
Some(RexType::ScalarFunction(ScalarFunction { @@ -1348,7 +1254,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsNull(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1356,7 +1262,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1364,7 +1270,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsTrue(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1372,7 +1278,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsFalse(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1380,7 +1286,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1388,7 +1294,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1396,7 +1302,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1404,7 +1310,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1412,7 +1318,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::Negative(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1420,7 +1326,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), _ => { not_impl_err!("Unsupported expression: {expr:?}") @@ -1428,7 +1334,11 @@ pub fn to_substrait_rex( } } -fn to_substrait_type(dt: &DataType, nullable: bool) -> Result { +fn to_substrait_type( + dt: &DataType, + nullable: bool, + extensions: &mut Extensions, +) -> Result { let nullability = if nullable { r#type::Nullability::Nullable as i32 } else { @@ -1548,7 +1458,9 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result Result { - let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; + let inner_type = + to_substrait_type(inner.data_type(), inner.is_nullable(), extensions)?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::List(Box::new(r#type::List { r#type: Some(Box::new(inner_type)), @@ -1599,7 +1512,8 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result { - let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; + let inner_type = + to_substrait_type(inner.data_type(), inner.is_nullable(), extensions)?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::List(Box::new(r#type::List { r#type: Some(Box::new(inner_type)), @@ -1613,10 +1527,12 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result Result { let field_types = fields .iter() - .map(|field| to_substrait_type(field.data_type(), field.is_nullable())) + .map(|field| { + to_substrait_type(field.data_type(), field.is_nullable(), extensions) + }) .collect::>>()?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Struct(r#type::Struct { @@ -1700,21 +1618,19 @@ fn make_substrait_like_expr( escape_char: Option, schema: &DFSchemaRef, col_ref_offset: usize, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { let function_anchor = if ignore_case { - register_function("ilike".to_string(), extension_info) + extensions.register_function("ilike".to_string()) } else 
{ - register_function("like".to_string(), extension_info) + extensions.register_function("like".to_string()) }; - let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; - let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?; - let escape_char = to_substrait_literal_expr(&ScalarValue::Utf8( - escape_char.map(|c| c.to_string()), - ))?; + let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; + let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extensions)?; + let escape_char = to_substrait_literal_expr( + &ScalarValue::Utf8(escape_char.map(|c| c.to_string())), + extensions, + )?; let arguments = vec![ FunctionArgument { arg_type: Some(ArgType::Value(expr)), @@ -1738,7 +1654,7 @@ fn make_substrait_like_expr( }; if negated { - let function_anchor = register_function("not".to_string(), extension_info); + let function_anchor = extensions.register_function("not".to_string()); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -1870,7 +1786,10 @@ fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> { )) } -fn to_substrait_literal(value: &ScalarValue) -> Result { +fn to_substrait_literal( + value: &ScalarValue, + extensions: &mut Extensions, +) -> Result { if value.is_null() { return Ok(Literal { nullable: true, @@ -1878,6 +1797,7 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { literal_type: Some(LiteralType::Null(to_substrait_type( &value.data_type(), true, + extensions, )?)), }); } @@ -1949,14 +1869,15 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { let bytes = i.to_byte_slice(); ( LiteralType::UserDefined(UserDefined { - type_reference: INTERVAL_MONTH_DAY_NANO_TYPE_REF, + type_reference: extensions + .register_type(INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string()), type_parameters: vec![], val: Some(user_defined::Val::Value(ProtoAny { - type_url: INTERVAL_MONTH_DAY_NANO_TYPE_URL.to_string(), + type_url: INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string(), value: bytes.to_vec().into(), })), }), - INTERVAL_MONTH_DAY_NANO_TYPE_REF, + DEFAULT_TYPE_VARIATION_REF, ) } ScalarValue::IntervalDayTime(Some(i)) => ( @@ -1996,20 +1917,65 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { DECIMAL_128_TYPE_VARIATION_REF, ), ScalarValue::List(l) => ( - convert_array_to_literal_list(l)?, + convert_array_to_literal_list(l, extensions)?, DEFAULT_CONTAINER_TYPE_VARIATION_REF, ), ScalarValue::LargeList(l) => ( - convert_array_to_literal_list(l)?, + convert_array_to_literal_list(l, extensions)?, LARGE_CONTAINER_TYPE_VARIATION_REF, ), + ScalarValue::Map(m) => { + let map = if m.is_empty() || m.value(0).is_empty() { + let mt = to_substrait_type(m.data_type(), m.is_nullable(), extensions)?; + let mt = match mt { + substrait::proto::Type { + kind: Some(r#type::Kind::Map(mt)), + } => Ok(mt.as_ref().to_owned()), + _ => exec_err!("Unexpected type for a map: {mt:?}"), + }?; + LiteralType::EmptyMap(mt) + } else { + let keys = (0..m.keys().len()) + .map(|i| { + to_substrait_literal( + &ScalarValue::try_from_array(&m.keys(), i)?, + extensions, + ) + }) + .collect::>>()?; + let values = (0..m.values().len()) + .map(|i| { + to_substrait_literal( + &ScalarValue::try_from_array(&m.values(), i)?, + extensions, + ) + }) + .collect::>>()?; + + let key_values = keys + .into_iter() + .zip(values.into_iter()) + .map(|(k, v)| { + Ok(KeyValue { + key: Some(k), + value: Some(v), + }) + }) + .collect::>>()?; + LiteralType::Map(Map { key_values }) + }; + (map, 
DEFAULT_CONTAINER_TYPE_VARIATION_REF) + } ScalarValue::Struct(s) => ( LiteralType::Struct(Struct { fields: s .columns() .iter() .map(|col| { - to_substrait_literal(&ScalarValue::try_from_array(col, 0)?) + to_substrait_literal( + &ScalarValue::try_from_array(col, 0)?, + extensions, + ) }) .collect::>>()?, }), @@ -2030,29 +1996,42 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { fn convert_array_to_literal_list( array: &GenericListArray, + extensions: &mut Extensions, ) -> Result { assert_eq!(array.len(), 1); let nested_array = array.value(0); let values = (0..nested_array.len()) - .map(|i| to_substrait_literal(&ScalarValue::try_from_array(&nested_array, i)?)) + .map(|i| { + to_substrait_literal( + &ScalarValue::try_from_array(&nested_array, i)?, + extensions, + ) + }) .collect::>>()?; if values.is_empty() { - let et = match to_substrait_type(array.data_type(), array.is_nullable())? { + let lt = match to_substrait_type( + array.data_type(), + array.is_nullable(), + extensions, + )? { substrait::proto::Type { kind: Some(r#type::Kind::List(lt)), } => lt.as_ref().to_owned(), _ => unreachable!(), }; - Ok(LiteralType::EmptyList(et)) + Ok(LiteralType::EmptyList(lt)) } else { Ok(LiteralType::List(List { values })) } } -fn to_substrait_literal_expr(value: &ScalarValue) -> Result { - let literal = to_substrait_literal(value)?; +fn to_substrait_literal_expr( + value: &ScalarValue, + extensions: &mut Extensions, +) -> Result { + let literal = to_substrait_literal(value, extensions)?; Ok(Expression { rex_type: Some(RexType::Literal(literal)), }) @@ -2065,14 +2044,10 @@ fn to_substrait_unary_scalar_fn( arg: &Expr, schema: &DFSchemaRef, col_ref_offset: usize, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { - let function_anchor = register_function(fn_name.to_string(), extension_info); - let substrait_expr = - to_substrait_rex(ctx, arg, schema, col_ref_offset, extension_info)?; + let function_anchor = extensions.register_function(fn_name.to_string()); + let substrait_expr = to_substrait_rex(ctx, arg, schema, col_ref_offset, extensions)?; Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -2116,10 +2091,7 @@ fn substrait_sort_field( ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { match expr { Expr::Sort(Sort { @@ -2127,7 +2099,7 @@ fn substrait_sort_field( asc, nulls_first, }) => { - let e = to_substrait_rex(ctx, expr, schema, 0, extension_info)?; + let e = to_substrait_rex(ctx, expr, schema, 0, extensions)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, (true, false) => SortDirection::AscNullsLast, @@ -2161,15 +2133,17 @@ fn substrait_field_ref(index: usize) -> Result { #[cfg(test)] mod test { + use super::*; use crate::logical_plan::consumer::{ from_substrait_literal_without_names, from_substrait_type_without_names, }; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; - use datafusion::arrow::array::GenericListArray; + use datafusion::arrow::array::{ + GenericListArray, Int64Builder, MapBuilder, StringBuilder, + }; use datafusion::arrow::datatypes::Field; use datafusion::common::scalar::ScalarStructBuilder; - - use super::*; + use std::collections::HashMap; #[test] fn round_trip_literals() -> Result<()> { @@ -2232,6 +2206,28 @@ mod test { ), )))?; + // Null map + let mut map_builder = + MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); + map_builder.append(false)?; + 
round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; + + // Empty map + let mut map_builder = + MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); + map_builder.append(true)?; + round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; + + // Valid map + let mut map_builder = + MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); + map_builder.keys().append_value("key1"); + map_builder.keys().append_value("key2"); + map_builder.values().append_value(1); + map_builder.values().append_value(2); + map_builder.append(true)?; + round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; + let c0 = Field::new("c0", DataType::Boolean, true); let c1 = Field::new("c1", DataType::Int32, true); let c2 = Field::new("c2", DataType::Utf8, true); @@ -2258,9 +2254,44 @@ mod test { fn round_trip_literal(scalar: ScalarValue) -> Result<()> { println!("Checking round trip of {scalar:?}"); - let substrait_literal = to_substrait_literal(&scalar)?; - let roundtrip_scalar = from_substrait_literal_without_names(&substrait_literal)?; + let mut extensions = Extensions::default(); + let substrait_literal = to_substrait_literal(&scalar, &mut extensions)?; + let roundtrip_scalar = + from_substrait_literal_without_names(&substrait_literal, &extensions)?; + assert_eq!(scalar, roundtrip_scalar); + Ok(()) + } + + #[test] + fn custom_type_literal_extensions() -> Result<()> { + let mut extensions = Extensions::default(); + // IntervalMonthDayNano is represented as a custom type in Substrait + let scalar = ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano::new( + 17, 25, 1234567890, + ))); + let substrait_literal = to_substrait_literal(&scalar, &mut extensions)?; + let roundtrip_scalar = + from_substrait_literal_without_names(&substrait_literal, &extensions)?; assert_eq!(scalar, roundtrip_scalar); + + assert_eq!( + extensions, + Extensions { + functions: HashMap::new(), + types: HashMap::from([( + 0, + INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string() + )]), + type_variations: HashMap::new(), + } + ); + + // Check we fail if we don't propagate extensions + assert!(from_substrait_literal_without_names( + &substrait_literal, + &Extensions::default() + ) + .is_err()); Ok(()) } @@ -2329,11 +2360,44 @@ mod test { fn round_trip_type(dt: DataType) -> Result<()> { println!("Checking round trip of {dt:?}"); + let mut extensions = Extensions::default(); + // As DataFusion doesn't consider nullability as a property of the type, but field, // it doesn't matter if we set nullability to true or false here. 
- let substrait = to_substrait_type(&dt, true)?; - let roundtrip_dt = from_substrait_type_without_names(&substrait)?; + let substrait = to_substrait_type(&dt, true, &mut extensions)?; + let roundtrip_dt = from_substrait_type_without_names(&substrait, &extensions)?; assert_eq!(dt, roundtrip_dt); Ok(()) } + + #[test] + fn custom_type_extensions() -> Result<()> { + let mut extensions = Extensions::default(); + // IntervalMonthDayNano is represented as a custom type in Substrait + let dt = DataType::Interval(IntervalUnit::MonthDayNano); + + let substrait = to_substrait_type(&dt, true, &mut extensions)?; + let roundtrip_dt = from_substrait_type_without_names(&substrait, &extensions)?; + assert_eq!(dt, roundtrip_dt); + + assert_eq!( + extensions, + Extensions { + functions: HashMap::new(), + types: HashMap::from([( + 0, + INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string() + )]), + type_variations: HashMap::new(), + } + ); + + // Check we fail if we don't propagate extensions + assert!( + from_substrait_type_without_names(&substrait, &Extensions::default()) + .is_err() + ); + + Ok(()) + } } diff --git a/datafusion/substrait/src/variation_const.rs b/datafusion/substrait/src/variation_const.rs index 27f4b3ea228a..c94ad2d669fd 100644 --- a/datafusion/substrait/src/variation_const.rs +++ b/datafusion/substrait/src/variation_const.rs @@ -25,13 +25,16 @@ //! - Default type reference is 0. It is used when the actual type is the same with the original type. //! - Extended variant type references start from 1, and ususlly increase by 1. //! -//! Definitions here are not the final form. All the non-system-preferred variations will be defined +//! TODO: Definitions here are not the final form. All the non-system-preferred variations will be defined //! using [simple extensions] as per the [spec of type_variations](https://substrait.io/types/type_variations/) +//! //! //! [simple extensions]: (https://substrait.io/extensions/#simple-extensions) // For [type variations](https://substrait.io/types/type_variations/#type-variations) in substrait. // Type variations are used to represent different types based on one type class. +// TODO: Define as extensions: + /// The "system-preferred" variation (i.e., no variation). pub const DEFAULT_TYPE_VARIATION_REF: u32 = 0; pub const UNSIGNED_INTEGER_TYPE_VARIATION_REF: u32 = 1; @@ -55,6 +58,7 @@ pub const DECIMAL_256_TYPE_VARIATION_REF: u32 = 1; /// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval /// [`IntervalUnit::YearMonth`]: datafusion::arrow::datatypes::IntervalUnit::YearMonth /// [`ScalarValue::IntervalYearMonth`]: datafusion::common::ScalarValue::IntervalYearMonth +#[deprecated(since = "41.0.0", note = "Use Substrait `IntervalYear` type instead")] pub const INTERVAL_YEAR_MONTH_TYPE_REF: u32 = 1; /// For [`DataType::Interval`] with [`IntervalUnit::DayTime`]. @@ -68,6 +72,7 @@ pub const INTERVAL_YEAR_MONTH_TYPE_REF: u32 = 1; /// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval /// [`IntervalUnit::DayTime`]: datafusion::arrow::datatypes::IntervalUnit::DayTime /// [`ScalarValue::IntervalDayTime`]: datafusion::common::ScalarValue::IntervalDayTime +#[deprecated(since = "41.0.0", note = "Use Substrait `IntervalDay` type instead")] pub const INTERVAL_DAY_TIME_TYPE_REF: u32 = 2; /// For [`DataType::Interval`] with [`IntervalUnit::MonthDayNano`]. 
@@ -82,21 +87,14 @@ pub const INTERVAL_DAY_TIME_TYPE_REF: u32 = 2; /// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval /// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano /// [`ScalarValue::IntervalMonthDayNano`]: datafusion::common::ScalarValue::IntervalMonthDayNano +#[deprecated( + since = "41.0.0", + note = "Use Substrait `UserDefinedType` with name `INTERVAL_MONTH_DAY_NANO_TYPE_NAME` instead" +)] pub const INTERVAL_MONTH_DAY_NANO_TYPE_REF: u32 = 3; -// For User Defined URLs -/// For [`DataType::Interval`] with [`IntervalUnit::YearMonth`]. -/// -/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval -/// [`IntervalUnit::YearMonth`]: datafusion::arrow::datatypes::IntervalUnit::YearMonth -pub const INTERVAL_YEAR_MONTH_TYPE_URL: &str = "interval-year-month"; -/// For [`DataType::Interval`] with [`IntervalUnit::DayTime`]. -/// -/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval -/// [`IntervalUnit::DayTime`]: datafusion::arrow::datatypes::IntervalUnit::DayTime -pub const INTERVAL_DAY_TIME_TYPE_URL: &str = "interval-day-time"; /// For [`DataType::Interval`] with [`IntervalUnit::MonthDayNano`]. /// /// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval /// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano -pub const INTERVAL_MONTH_DAY_NANO_TYPE_URL: &str = "interval-month-day-nano"; +pub const INTERVAL_MONTH_DAY_NANO_TYPE_NAME: &str = "interval-month-day-nano"; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index a7653e11d598..439e3efa2922 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -38,7 +38,10 @@ use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLI use datafusion::prelude::*; use datafusion::execution::session_state::SessionStateBuilder; -use substrait::proto::extensions::simple_extension_declaration::MappingType; +use substrait::proto::extensions::simple_extension_declaration::{ + ExtensionType, MappingType, +}; +use substrait::proto::extensions::SimpleExtensionDeclaration; use substrait::proto::rel::RelType; use substrait::proto::{plan_rel, Plan, Rel}; @@ -175,15 +178,46 @@ async fn select_with_filter() -> Result<()> { #[tokio::test] async fn select_with_reused_functions() -> Result<()> { + let ctx = create_context().await?; let sql = "SELECT * FROM data WHERE a > 1 AND a < 10 AND b > 0"; - roundtrip(sql).await?; - let (mut function_names, mut function_anchors) = function_extension_info(sql).await?; - function_names.sort(); - function_anchors.sort(); + let proto = roundtrip_with_ctx(sql, ctx).await?; + let mut functions = proto + .extensions + .iter() + .map(|e| match e.mapping_type.as_ref().unwrap() { + MappingType::ExtensionFunction(ext_f) => { + (ext_f.function_anchor, ext_f.name.to_owned()) + } + _ => unreachable!("Non-function extensions not expected"), + }) + .collect::>(); + functions.sort_by_key(|(anchor, _)| *anchor); + + // Functions are encountered (and thus registered) depth-first + let expected = vec![ + (0, "gt".to_string()), + (1, "lt".to_string()), + (2, "and".to_string()), + ]; + assert_eq!(functions, expected); - assert_eq!(function_names, ["and", "gt", "lt"]); - assert_eq!(function_anchors, [0, 1, 2]); + Ok(()) +} +#[tokio::test] +async fn roundtrip_udt_extensions() -> Result<()> { + let ctx 
= create_context().await?; + let proto = + roundtrip_with_ctx("SELECT INTERVAL '1 YEAR 1 DAY 1 SECOND' FROM data", ctx) + .await?; + let expected_type = SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionType(ExtensionType { + extension_uri_reference: u32::MAX, + type_anchor: 0, + name: "interval-month-day-nano".to_string(), + })), + }; + assert_eq!(proto.extensions, vec![expected_type]); Ok(()) } @@ -715,7 +749,7 @@ async fn roundtrip_values() -> Result<()> { [[-213.1, NULL, 5.5, 2.0, 1.0], []], \ arrow_cast([1,2,3], 'LargeList(Int64)'), \ STRUCT(true, 1 AS int_field, CAST(NULL AS STRING)), \ - [STRUCT(STRUCT('a' AS string_field) AS struct_field)]\ + [STRUCT(STRUCT('a' AS string_field) AS struct_field), STRUCT(STRUCT('b' AS string_field) AS struct_field)]\ ), \ (NULL, NULL, NULL, NULL, NULL, NULL)", "Values: \ @@ -725,7 +759,7 @@ async fn roundtrip_values() -> Result<()> { List([[-213.1, , 5.5, 2.0, 1.0], []]), \ LargeList([1, 2, 3]), \ Struct({c0:true,int_field:1,c2:}), \ - List([{struct_field: {string_field: a}}])\ + List([{struct_field: {string_field: a}}, {struct_field: {string_field: b}}])\ ), \ (Int64(NULL), Utf8(NULL), List(), LargeList(), Struct({c0:,int_field:,c2:}), List())", true).await @@ -858,7 +892,8 @@ async fn roundtrip_aggregate_udf() -> Result<()> { let ctx = create_context().await?; ctx.register_udaf(dummy_agg); - roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await + roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await?; + Ok(()) } #[tokio::test] @@ -891,7 +926,8 @@ async fn roundtrip_window_udf() -> Result<()> { let ctx = create_context().await?; ctx.register_udwf(dummy_agg); - roundtrip_with_ctx("select dummy_window(a) OVER () from data", ctx).await + roundtrip_with_ctx("select dummy_window(a) OVER () from data", ctx).await?; + Ok(()) } #[tokio::test] @@ -1083,7 +1119,7 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> { Ok(()) } -async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result<()> { +async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result> { let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; let proto = to_substrait_plan(&plan, &ctx)?; @@ -1102,56 +1138,25 @@ async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result<()> { assert_eq!(plan.schema(), plan2.schema()); DataFrame::new(ctx.state(), plan2).show().await?; - Ok(()) + Ok(proto) } async fn roundtrip(sql: &str) -> Result<()> { - roundtrip_with_ctx(sql, create_context().await?).await + roundtrip_with_ctx(sql, create_context().await?).await?; + Ok(()) } async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> { let ctx = create_context().await?; - let df = ctx.sql(sql).await?; - let plan = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; - let plan2 = ctx.state().optimize(&plan2)?; - - println!("{plan:#?}"); - println!("{plan2:#?}"); - - let plan1str = format!("{plan:?}"); - let plan2str = format!("{plan2:?}"); - assert_eq!(plan1str, plan2str); - - assert_eq!(plan.schema(), plan2.schema()); + let proto = roundtrip_with_ctx(sql, ctx).await?; // verify that the join filters are None verify_post_join_filter_value(proto).await } async fn roundtrip_all_types(sql: &str) -> Result<()> { - roundtrip_with_ctx(sql, create_all_type_context().await?).await -} - -async fn function_extension_info(sql: &str) -> Result<(Vec, Vec)> { - let ctx = create_context().await?; - let df = ctx.sql(sql).await?; - let 
plan = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan, &ctx)?; - - let mut function_names: Vec<String> = vec![]; - let mut function_anchors: Vec<u32> = vec![]; - for e in &proto.extensions { - let (function_anchor, function_name) = match e.mapping_type.as_ref().unwrap() { - MappingType::ExtensionFunction(ext_f) => (ext_f.function_anchor, &ext_f.name), - _ => unreachable!("Producer does not generate a non-function extension"), - }; - function_names.push(function_name.to_string()); - function_anchors.push(function_anchor); - } - - Ok((function_names, function_anchors)) + roundtrip_with_ctx(sql, create_all_type_context().await?).await?; + Ok(()) } async fn create_context() -> Result<SessionContext> { diff --git a/docs/source/index.rst b/docs/source/index.rst index ca6905c434f3..9c8c886d2502 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -107,7 +107,7 @@ To get started, see library-user-guide/custom-table-providers library-user-guide/extending-operators library-user-guide/profiling - + library-user-guide/query-optimizer .. _toc.contributor-guide: .. toctree:: diff --git a/docs/source/library-user-guide/query-optimizer.md b/docs/source/library-user-guide/query-optimizer.md new file mode 100644 index 000000000000..5aacfaf59cb1 --- /dev/null +++ b/docs/source/library-user-guide/query-optimizer.md @@ -0,0 +1,336 @@ + + +# DataFusion Query Optimizer + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory +format. + +DataFusion has a modular design, allowing individual crates to be re-used in other projects. + +This crate is a submodule of DataFusion that provides a query optimizer for logical plans, and +contains an extensive set of OptimizerRules that may rewrite the plan and/or its expressions so +they execute more quickly while still computing the same result. + +## Running the Optimizer + +The following code demonstrates the basic flow of creating the optimizer with a default set of optimization rules +and applying it to a logical plan to produce an optimized logical plan. + +```rust + +// We need a logical plan as the starting point. There are many ways to build a logical plan: +// +// The `datafusion-expr` crate provides a LogicalPlanBuilder +// The `datafusion-sql` crate provides a SQL query planner that can create a LogicalPlan from SQL +// The `datafusion` crate provides a DataFrame API that can create a LogicalPlan +let logical_plan = ... + +let mut config = OptimizerContext::default(); +let optimizer = Optimizer::new(&config); +let optimized_plan = optimizer.optimize(&logical_plan, &config, observe)?; + +fn observe(plan: &LogicalPlan, rule: &dyn OptimizerRule) { + println!( + "After applying rule '{}':\n{}", + rule.name(), + plan.display_indent() + ) +} +``` + +## Providing Custom Rules + +The optimizer can be created with a custom set of rules. + +```rust +let optimizer = Optimizer::with_rules(vec![ + Arc::new(MyRule {}) +]); +``` + +## Writing Optimization Rules + +Please refer to the +[optimizer_rule.rs](../../datafusion-examples/examples/optimizer_rule.rs) +example to learn more about the general approach to writing optimizer rules and +then move on to studying the existing rules. + +All rules must implement the `OptimizerRule` trait. + +```rust +/// `OptimizerRule` transforms one [`LogicalPlan`] into another which +/// computes the same results, but in a potentially more efficient +/// way. If there are no suitable transformations for the input plan, +/// the optimizer can simply return it as is.
+pub trait OptimizerRule { + /// Rewrite `plan` to an optimized form + fn optimize( + &self, + plan: &LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result<LogicalPlan>; + + /// A human readable name for this optimizer rule + fn name(&self) -> &str; +} +``` + +### General Guidelines + +Rules typically walk the logical plan and the expression trees inside operators, selectively mutating +individual operators or expressions. + +Sometimes there is an initial pass that visits the plan and builds state that is used in a second pass that performs +the actual optimization. This approach is used in projection push down and filter push down. + +### Expression Naming + +Every expression in DataFusion has a name, which is used as the column name. For example, the output of the following query +contains a single column with the name `"COUNT(aggregate_test_100.c9)"`: + +```text +> select count(c9) from aggregate_test_100; ++------------------------------+ +| COUNT(aggregate_test_100.c9) | ++------------------------------+ +| 100 | ++------------------------------+ +``` + +These names are used to refer to the columns in both subqueries as well as internally from one stage of the LogicalPlan +to another. For example: + +```text +> select "COUNT(aggregate_test_100.c9)" + 1 from (select count(c9) from aggregate_test_100) as sq; ++--------------------------------------------+ +| sq.COUNT(aggregate_test_100.c9) + Int64(1) | ++--------------------------------------------+ +| 101 | ++--------------------------------------------+ +``` + +### Implication + +Because DataFusion identifies columns using a string name, it is critical that the names of expressions are +not changed by the optimizer when it rewrites expressions. This is typically accomplished by renaming a rewritten +expression by adding an alias. + +Here is a simple example of such a rewrite. The expression `1 + 2` can be internally simplified to 3 but must still be +displayed the same as `1 + 2`: + +```text +> select 1 + 2; ++---------------------+ +| Int64(1) + Int64(2) | ++---------------------+ +| 3 | ++---------------------+ +``` + +Looking at the `EXPLAIN` output we can see that the optimizer has effectively rewritten `1 + 2` into +`3 as "1 + 2"`: + +```text +> explain select 1 + 2; ++---------------+-------------------------------------------------+ +| plan_type | plan | ++---------------+-------------------------------------------------+ +| logical_plan | Projection: Int64(3) AS Int64(1) + Int64(2) | +| | EmptyRelation | +| physical_plan | ProjectionExec: expr=[3 as Int64(1) + Int64(2)] | +| | PlaceholderRowExec | +| | | ++---------------+-------------------------------------------------+ +``` + +If the expression name is not preserved, bugs such as [#3704](https://github.com/apache/datafusion/issues/3704) +and [#3555](https://github.com/apache/datafusion/issues/3555) occur where the expected columns cannot be found. + +### Building Expression Names + +There are currently two ways to create a name for an expression in the logical plan. + +```rust +impl Expr { + /// Returns the name of this expression as it should appear in a schema. This name + /// will not include any CAST expressions. + pub fn display_name(&self) -> Result<String> { + create_name(self) + } + + /// Returns a full and complete string representation of this expression.
+ pub fn canonical_name(&self) -> String { + format!("{}", self) + } +} +``` + +When comparing expressions to determine if they are equivalent, `canonical_name` should be used, and when creating a +name to be used in a schema, `display_name` should be used. + +### Utilities + +There are a number of utility methods provided that take care of some common tasks. + +### ExprVisitor + +The `ExprVisitor` and `ExprVisitable` traits provide a mechanism for applying a visitor pattern to an expression tree. + +Here is an example that demonstrates this. + +```rust +fn extract_subquery_filters(expression: &Expr, extracted: &mut Vec<Expr>) -> Result<()> { + struct InSubqueryVisitor<'a> { + accum: &'a mut Vec<Expr>, + } + + impl ExpressionVisitor for InSubqueryVisitor<'_> { + fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>> { + if let Expr::InSubquery(_) = expr { + self.accum.push(expr.to_owned()); + } + Ok(Recursion::Continue(self)) + } + } + + expression.accept(InSubqueryVisitor { accum: extracted })?; + Ok(()) +} +``` + +### Rewriting Expressions + +The `ExprRewriter` trait can be implemented to provide a way to rewrite expressions. This rule can then be applied +to an expression by calling `Expr::rewrite` (from the `ExprRewritable` trait). + +The `rewrite` method will perform a depth-first walk of the expression and its children to rewrite an expression, +consuming `self` and producing a new expression. + +```rust +let mut expr_rewriter = MyExprRewriter {}; +let expr = expr.rewrite(&mut expr_rewriter)?; +``` + +Here is an example implementation which will rewrite `expr BETWEEN a AND b` as `expr >= a AND expr <= b`. Note that the +implementation does not need to perform any recursion since this is handled by the `rewrite` method. + +```rust +struct MyExprRewriter {} + +impl ExprRewriter for MyExprRewriter { + fn mutate(&mut self, expr: Expr) -> Result<Expr> { + match expr { + Expr::Between { + negated, + expr, + low, + high, + } => { + let expr: Expr = expr.as_ref().clone(); + let low: Expr = low.as_ref().clone(); + let high: Expr = high.as_ref().clone(); + if negated { + Ok(expr.clone().lt(low).or(expr.clone().gt(high))) + } else { + Ok(expr.clone().gt_eq(low).and(expr.clone().lt_eq(high))) + } + } + _ => Ok(expr.clone()), + } + } +} +``` + +### optimize_children + +Typically a rule is applied recursively to all operators within a query plan. Rather than duplicate +that logic in each rule, an `optimize_children` method is provided. This recursively invokes the `optimize` method on +the plan's children and then returns a node of the same type. + +```rust +fn optimize( + &self, + plan: &LogicalPlan, + _config: &mut OptimizerConfig, +) -> Result<LogicalPlan> { + // recurse down and optimize children first + let plan = utils::optimize_children(self, plan, _config)?; + + ... +} +``` + +### Writing Tests + +There should be unit tests in the same file as the new rule that test the effect of the rule being applied to a plan +in isolation (without any other rule being applied). + +There should also be a test in `integration-tests.rs` that tests the rule as part of the overall optimization process. + +### Debugging + +The `EXPLAIN VERBOSE` command can be used to show the effect of each optimization rule on a query. + +In the example below, the `type_coercion` and `simplify_expressions` passes have simplified the plan so that it returns the constant `"3.2"` rather than doing a computation at execution time.
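One convenient way to produce this kind of output programmatically is through the DataFrame API. The snippet below is only a minimal sketch, assuming the `datafusion` and `tokio` crates as dependencies: it runs the `EXPLAIN VERBOSE` statement shown in the output below and prints the result.

```rust
// Minimal sketch: run EXPLAIN VERBOSE through the DataFrame API and print the
// per-rule plan table. Assumes the `datafusion` and `tokio` crates.
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Planning this statement runs the optimizer; each rule's effect is
    // reported as a separate row in the verbose explain output.
    ctx.sql("EXPLAIN VERBOSE SELECT CAST(1 + 2.2 AS STRING) AS foo")
        .await?
        .show()
        .await?;
    Ok(())
}
```

Running the same statement in `datafusion-cli` produces a table like the following: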
+ +```text +> explain verbose select cast(1 + 2.2 as string) as foo; ++------------------------------------------------------------+---------------------------------------------------------------------------+ +| plan_type | plan | ++------------------------------------------------------------+---------------------------------------------------------------------------+ +| initial_logical_plan | Projection: CAST(Int64(1) + Float64(2.2) AS Utf8) AS foo | +| | EmptyRelation | +| logical_plan after type_coercion | Projection: CAST(CAST(Int64(1) AS Float64) + Float64(2.2) AS Utf8) AS foo | +| | EmptyRelation | +| logical_plan after simplify_expressions | Projection: Utf8("3.2") AS foo | +| | EmptyRelation | +| logical_plan after unwrap_cast_in_comparison | SAME TEXT AS ABOVE | +| logical_plan after decorrelate_where_exists | SAME TEXT AS ABOVE | +| logical_plan after decorrelate_where_in | SAME TEXT AS ABOVE | +| logical_plan after scalar_subquery_to_join | SAME TEXT AS ABOVE | +| logical_plan after subquery_filter_to_join | SAME TEXT AS ABOVE | +| logical_plan after simplify_expressions | SAME TEXT AS ABOVE | +| logical_plan after eliminate_filter | SAME TEXT AS ABOVE | +| logical_plan after reduce_cross_join | SAME TEXT AS ABOVE | +| logical_plan after common_sub_expression_eliminate | SAME TEXT AS ABOVE | +| logical_plan after eliminate_limit | SAME TEXT AS ABOVE | +| logical_plan after projection_push_down | SAME TEXT AS ABOVE | +| logical_plan after rewrite_disjunctive_predicate | SAME TEXT AS ABOVE | +| logical_plan after reduce_outer_join | SAME TEXT AS ABOVE | +| logical_plan after filter_push_down | SAME TEXT AS ABOVE | +| logical_plan after limit_push_down | SAME TEXT AS ABOVE | +| logical_plan after single_distinct_aggregation_to_group_by | SAME TEXT AS ABOVE | +| logical_plan | Projection: Utf8("3.2") AS foo | +| | EmptyRelation | +| initial_physical_plan | ProjectionExec: expr=[3.2 as foo] | +| | PlaceholderRowExec | +| | | +| physical_plan after aggregate_statistics | SAME TEXT AS ABOVE | +| physical_plan after join_selection | SAME TEXT AS ABOVE | +| physical_plan after coalesce_batches | SAME TEXT AS ABOVE | +| physical_plan after repartition | SAME TEXT AS ABOVE | +| physical_plan after add_merge_exec | SAME TEXT AS ABOVE | +| physical_plan | ProjectionExec: expr=[3.2 as foo] | +| | PlaceholderRowExec | +| | | ++------------------------------------------------------------+---------------------------------------------------------------------------+ +``` + +[df]: https://crates.io/crates/datafusion diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 5130b0a56d0e..e4b849cd28bb 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -44,37 +44,38 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.catalog.location | NULL | Location scanned to load tables for `default` schema | | datafusion.catalog.format | NULL | Type of `TableProvider` to use when loading `default` schema | | datafusion.catalog.has_header | false | Default value for `format.has_header` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. | +| datafusion.catalog.newlines_in_values | false | Specifies whether newlines in (quoted) CSV values are supported. This is the default value for `format.newlines_in_values` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. 
Parsing newlines in quoted values may be affected by execution behaviour such as parallel file scanning. Setting this to `true` ensures that newlines in values are parsed successfully, which may reduce performance. | | datafusion.execution.batch_size | 8192 | Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption | | datafusion.execution.coalesce_batches | true | When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting | | datafusion.execution.collect_statistics | false | Should DataFusion collect statistics after listing files | | datafusion.execution.target_partitions | 0 | Number of partitions for query execution. Increasing partitions can increase concurrency. Defaults to the number of CPU cores on the system | | datafusion.execution.time_zone | +00:00 | The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour | -| datafusion.execution.parquet.enable_page_index | true | If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. | -| datafusion.execution.parquet.pruning | true | If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file | -| datafusion.execution.parquet.skip_metadata | true | If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata | -| datafusion.execution.parquet.metadata_size_hint | NULL | If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer | -| datafusion.execution.parquet.pushdown_filters | false | If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". | -| datafusion.execution.parquet.reorder_filters | false | If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query | -| datafusion.execution.parquet.data_pagesize_limit | 1048576 | Sets best effort maximum size of data page in bytes | -| datafusion.execution.parquet.write_batch_size | 1024 | Sets write_batch_size in bytes | -| datafusion.execution.parquet.writer_version | 1.0 | Sets parquet writer version valid values are "1.0" and "2.0" | -| datafusion.execution.parquet.compression | zstd(3) | Sets default parquet compression codec Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.dictionary_enabled | NULL | Sets if dictionary encoding is enabled. 
If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.dictionary_page_size_limit | 1048576 | Sets best effort maximum dictionary page size, in bytes | -| datafusion.execution.parquet.statistics_enabled | NULL | Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.max_statistics_size | NULL | Sets max statistics size for any column. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.max_row_group_size | 1048576 | Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. | -| datafusion.execution.parquet.created_by | datafusion version 40.0.0 | Sets "created by" property | -| datafusion.execution.parquet.column_index_truncate_length | NULL | Sets column index truncate length | -| datafusion.execution.parquet.data_page_row_count_limit | 18446744073709551615 | Sets best effort maximum number of rows in data page | -| datafusion.execution.parquet.encoding | NULL | Sets default encoding for any column Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.bloom_filter_on_read | true | Use any available bloom filters when reading parquet files | -| datafusion.execution.parquet.bloom_filter_on_write | false | Write bloom filters for all columns when creating parquet files | -| datafusion.execution.parquet.bloom_filter_fpp | NULL | Sets bloom filter false positive probability. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.bloom_filter_ndv | NULL | Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.allow_single_file_parallelism | true | Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. | -| datafusion.execution.parquet.maximum_parallel_row_group_writers | 1 | By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | -| datafusion.execution.parquet.maximum_buffered_record_batches_per_stream | 2 | By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. 
| +| datafusion.execution.parquet.enable_page_index | true | (reading) If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. | +| datafusion.execution.parquet.pruning | true | (reading) If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file | +| datafusion.execution.parquet.skip_metadata | true | (reading) If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata | +| datafusion.execution.parquet.metadata_size_hint | NULL | (reading) If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer | +| datafusion.execution.parquet.pushdown_filters | false | (reading) If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". | +| datafusion.execution.parquet.reorder_filters | false | (reading) If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query | +| datafusion.execution.parquet.data_pagesize_limit | 1048576 | (writing) Sets best effort maximum size of data page in bytes | +| datafusion.execution.parquet.write_batch_size | 1024 | (writing) Sets write_batch_size in bytes | +| datafusion.execution.parquet.writer_version | 1.0 | (writing) Sets parquet writer version valid values are "1.0" and "2.0" | +| datafusion.execution.parquet.compression | zstd(3) | (writing) Sets default parquet compression codec. Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting Note that this default setting is not the same as the default parquet writer setting. | +| datafusion.execution.parquet.dictionary_enabled | true | (writing) Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.dictionary_page_size_limit | 1048576 | (writing) Sets best effort maximum dictionary page size, in bytes | +| datafusion.execution.parquet.statistics_enabled | NULL | (writing) Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.max_statistics_size | 4096 | (writing) Sets max statistics size for any column. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.max_row_group_size | 1048576 | (writing) Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. 
| +| datafusion.execution.parquet.created_by | datafusion version 40.0.0 | (writing) Sets "created by" property | +| datafusion.execution.parquet.column_index_truncate_length | 64 | (writing) Sets column index truncate length | +| datafusion.execution.parquet.data_page_row_count_limit | 20000 | (writing) Sets best effort maximum number of rows in data page | +| datafusion.execution.parquet.encoding | NULL | (writing) Sets default encoding for any column. Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.bloom_filter_on_read | true | (writing) Use any available bloom filters when reading parquet files | +| datafusion.execution.parquet.bloom_filter_on_write | false | (writing) Write bloom filters for all columns when creating parquet files | +| datafusion.execution.parquet.bloom_filter_fpp | NULL | (writing) Sets bloom filter false positive probability. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.bloom_filter_ndv | NULL | (writing) Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.allow_single_file_parallelism | true | (writing) Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. | +| datafusion.execution.parquet.maximum_parallel_row_group_writers | 1 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | +| datafusion.execution.parquet.maximum_buffered_record_batches_per_stream | 2 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | | datafusion.execution.aggregate.scalar_update_factor | 10 | Specifies the threshold for using `ScalarValue`s to update accumulators during high-cardinality aggregations for each input batch. The aggregation is considered high-cardinality if the number of affected groups is greater than or equal to `batch_size / scalar_update_factor`. In such cases, `ScalarValue`s are utilized for updating accumulators, rather than the default batch-slice approach. This can lead to performance improvements. By adjusting the `scalar_update_factor`, you can balance the trade-off between more efficient accumulator updates and the number of groups affected. | | datafusion.execution.planning_concurrency | 0 | Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. 
Defaults to the number of CPU cores on the system | | datafusion.execution.sort_spill_reservation_bytes | 10485760 | Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). | diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index d2e012cf4093..561824772af8 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -772,7 +772,7 @@ concat(str[, ..., str_n]) Concatenates multiple strings together with a specified separator. ``` -concat(separator, str[, ..., str_n]) +concat_ws(separator, str[, ..., str_n]) ``` #### Arguments